/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"

#include <memory>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {
namespace gpu {
namespace {

namespace m = match;

// Give this instruction a more useful name than "custom-call.42".
Status SetName(HloModule *module, HloInstruction *gemm) {
  if (IsCublasLtMatmul(*gemm)) {
    module->SetAndUniquifyInstrName(gemm, "cublas-lt-matmul");
    return OkStatus();
  }

  GemmBackendConfig config;
  TF_ASSIGN_OR_RETURN(config, gemm->backend_config<GemmBackendConfig>());
  const DotDimensionNumbers &dot_dims = config.dot_dimension_numbers();
  bool is_batch_dot = !dot_dims.lhs_batch_dimensions().empty() ||
                      !dot_dims.rhs_batch_dimensions().empty();

  module->SetAndUniquifyInstrName(
      gemm, is_batch_dot ? "cublas-batch-gemm" : "cublas-gemm");
  return OkStatus();
}

// If the bias is a sequence of ops that depend only on broadcasts of
// constants, materialize the bias if it's small.
//
// Normally the constant-folding pass would materialize the bias if it is
// calculated entirely from constants. But if the bias is a broadcast of a
// constant, constant-folding won't expand the broadcast, on the theory that
// folding broadcasts of constants causes us to consume more memory and can
// actually make things slower (because any op which reads the constant has
// to read more memory).
//
// OTOH in our case, we don't want to run an op that just broadcasts a
// constant so we can fuse it into this gemm. That would defeat the whole
// purpose of this fusion, which is to launch fewer kernels. So if we can,
// we expand out this constant ourselves.
//
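// As a purely illustrative sketch (the shapes and names below are made up,
// not taken from any model), a bias such as
//
//   c = f32[256] constant({...})
//   bias = f32[128,256] broadcast(c), dimensions={1}
//
// is evaluated into a single f32[128,256] constant, so no standalone
// broadcast kernel needs to run before the fused GEMM.
//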
// TODO(b/192499646): Even better would be to use cublasLT to fuse the
// broadcasted bias, if it supports that fusion efficiently.
HloInstruction *MaybeConstantFoldBias(HloInstruction *bias) {
  // This limit was not chosen carefully.
  constexpr int kMaxMaterializeBiasBytes = 8 * 1024 * 1024;

  // Don't fold broadcasts of scalars -- algsimp will just collapse it again.
  auto is_nonscalar = [](const HloInstruction *instr) {
    return !ShapeUtil::IsEffectiveScalar(instr->shape());
  };

  // For now, only fold broadcast(constant) or
  // reshape/transpose/bitcast(broadcast(constant)). This lets us avoid the
  // complexity in the constant-folding pass about what is and isn't legal to
  // fold.
  auto broadcast_of_nonscalar =
      m::Broadcast(m::Constant().WithPredicate(is_nonscalar));

  if (ShapeUtil::ByteSizeOf(bias->shape()) <= kMaxMaterializeBiasBytes &&
      (Match(bias, broadcast_of_nonscalar) ||
       Match(bias, m::Reshape(broadcast_of_nonscalar)) ||
       Match(bias, m::Transpose(broadcast_of_nonscalar)) ||
       Match(bias, m::Bitcast(broadcast_of_nonscalar)))) {
    HloEvaluator evaluator(/*max_loop_iterations=*/0);
    Literal result;
    if (evaluator.TryEvaluate(
            bias, &result,
            /*recursively_evaluate_nonconstant_operands=*/true)) {
      return bias->parent()->AddInstruction(
          HloInstruction::CreateConstant(std::move(result)));
    }
  }

  return bias;
}

// The rewriting proceeds in a bottom-up way:
//
// (kDot A B) is rewritten into a (kCustomCall:gemm A B)
//
// (kMultiply (kCustomCall:gemm A B) C) folds C (provided it is a broadcast
// of a constant scalar) into the alpha parameter of the custom call.
//
// (kAdd (kCustomCall:gemm A B) C) is rewritten into (kCustomCall:gemm A B C),
// where the "beta" parameter is set to 1 (provided it was zero before,
// and provided C has no other users).
// We then guide the buffer assignment to alias the buffer of the custom call
// and C.
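//
// As a schematic example (the HLO names and shapes below are illustrative,
// and the target string is only assumed to be the one kGemmCallTarget refers
// to):
//
//   dot = f32[16,32] dot(a, b), lhs_contracting_dims={1},
//                               rhs_contracting_dims={0}
//
// becomes roughly
//
//   gemm = f32[16,32] custom-call(a, b), custom_call_target="__cublas$gemm",
//            backend_config={alpha_real:1, alpha_imag:0, beta:0, ...}
//
// and a later (kAdd gemm c) is then absorbed as a third operand with beta=1.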
class GemmRewriterVisitor : public DfsHloRewriteVisitor {
 public:
  Status HandleDot(HloInstruction *instr) override {
    if (IsMatrixMultiplication(*instr)) {
      CHECK(!instr->IsRank2Transpose());
      HloInstruction *lhs = instr->mutable_operand(0);
      HloInstruction *rhs = instr->mutable_operand(1);
      CHECK(!lhs->IsRank2Transpose());
      CHECK(!rhs->IsRank2Transpose());
      const Shape &output_shape = instr->shape();

      const char *const target =
          instr->GetModule()->config().debug_options().xla_gpu_enable_cublaslt()
              ? kCublasLtMatmulCallTarget
              : kGemmCallTarget;

      std::unique_ptr<HloInstruction> gemm_call =
          HloInstruction::CreateCustomCall(output_shape, {lhs, rhs}, target);
      GemmBackendConfig gemm_config;
      gemm_config.set_alpha_real(1.0);
      gemm_config.set_alpha_imag(0.0);
      gemm_config.set_beta(0.0);
      *gemm_config.mutable_dot_dimension_numbers() =
          instr->dot_dimension_numbers();
      *gemm_config.mutable_precision_config() = instr->precision_config();

      TF_RETURN_IF_ERROR(gemm_call->set_backend_config(gemm_config));
      TF_RETURN_IF_ERROR(SetName(instr->GetModule(), gemm_call.get()));
      TF_RETURN_IF_ERROR(
          ReplaceWithNewInstruction(instr, std::move(gemm_call)));
    }
    return OkStatus();
  }

  Status HandleMultiply(HloInstruction *instr) override {
    HloInstruction *alpha, *existing_gemm;
    if (Match(instr,
              m::MultiplyAnyOrder(
                  m::Op(&existing_gemm).WithCustomCallTarget(kGemmCallTarget),
                  m::Broadcast(m::ConstantScalar(&alpha))))) {
      TF_ASSIGN_OR_RETURN(auto config,
                          existing_gemm->backend_config<GemmBackendConfig>());

      // Do not fuse alpha into an S32 GEMM, as cuBLAS only supports fixed
      // values for alpha/beta with this datatype.
      if (existing_gemm->shape().element_type() == S32) {
        return OkStatus();
      }

      if (config.beta() == 0.0 && existing_gemm->user_count() == 1) {
        complex128 prev_alpha = {config.alpha_real(), config.alpha_imag()};
        complex128 new_alpha =
            *alpha->literal().GetAsComplex128({}) * prev_alpha;
        config.set_alpha_real(new_alpha.real());
        config.set_alpha_imag(new_alpha.imag());
        TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(config));
        TF_RETURN_IF_ERROR(ReplaceInstruction(instr, existing_gemm));
      }
    }
    return OkStatus();
  }

  Status HandleAdd(HloInstruction *instr) override {
    HloInstruction *bias, *existing_gemm;

    // First, try to match vector bias add, so we might elide the broadcast.
    if (Match(instr, m::AddAnyOrder(
                         m::Op(&existing_gemm)
                             .WithCustomCallTarget(kCublasLtMatmulCallTarget),
                         m::Broadcast(&bias, m::Op())))) {
      TF_ASSIGN_OR_RETURN(bool was_fused,
                          FuseVectorBiasAdd(instr, bias, existing_gemm));
      if (was_fused) return OkStatus();
    }

    // add(bitcast(gemm(a, b)), bias) ->
    //   bitcast(add(gemm(a, b), bitcast(bias))) ->
    //   bitcast(gemm(a, b, bitcast(bias))) (later down in this function).
    //
    // We see this idiom in models that contain batch-dots, where we cast
    // between a rank-2 shape for non-batch dots and a higher-rank shape for
    // batch-dots.
    //
    // The last stage of the transform may fail (because of any of the checks
    // in FuseMatrixBiasAdd), but if so that's okay -- we'll have done a
    // useless transformation, but it doesn't hurt anything.
    if (Match(instr, m::AddAnyOrder(
                         m::Bitcast(m::Op(&existing_gemm)
                                        .WithCustomCallTarget(kGemmCallTarget)
                                        .WithOneUser())
                             .WithOneUser(),
                         m::Op(&bias)))) {
      HloInstruction *new_bitcast =
          MakeBitcastHlo(bias, existing_gemm->shape(), &bias->metadata());
      TF_ASSIGN_OR_RETURN(HloInstruction * new_add,
                          MakeBinaryHlo(HloOpcode::kAdd, existing_gemm,
                                        new_bitcast, &bias->metadata()));
      TF_RETURN_IF_ERROR(
          ReplaceInstruction(instr, MakeBitcastHlo(new_add, instr->shape())));

      // Continue below transforming new_add.
      instr = new_add;
    }

    if (Match(instr, m::AddAnyOrder(
                         m::Op(&existing_gemm)
                             .WithCustomCallTarget(
                                 {kGemmCallTarget, kCublasLtMatmulCallTarget}),
                         m::Op(&bias)))) {
      return FuseMatrixBiasAdd(instr, bias, existing_gemm);
    }

    return OkStatus();
  }

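  // Matches a matrix bias add that is computed at higher precision (typically
  // F32) around a BF16 GEMM, i.e. convert(gemm) + convert(bias) with the sum
  // converted back to BF16, and forwards it to FuseMatrixBiasAdd so the bias
  // is fused directly into the GEMM.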
  Status HandleConvert(HloInstruction *instr) override {
    HloInstruction *bias, *existing_gemm;
    if (Match(
            instr,
            m::Convert(m::AddAnyOrder(
                           m::Convert(m::Op(&existing_gemm)
                                          .WithCustomCallTarget(kGemmCallTarget)
                                          .WithElementType(BF16)),
                           m::Convert(m::Op(&bias).WithElementType(BF16))))
                .WithElementType(BF16))) {
      return FuseMatrixBiasAdd(instr, bias, existing_gemm);
    }
    return OkStatus();
  }

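  // Rewrites add(gemm(a, b), bias), where bias has the same shape as the GEMM
  // output, into gemm(a, b, bias) with beta set to 1, subject to the checks
  // below.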
  Status FuseMatrixBiasAdd(HloInstruction *instr, HloInstruction *bias,
                           HloInstruction *gemm) {
    TF_RET_CHECK(bias->shape() == gemm->shape());

    // Do not fuse bias into S32 GEMM, as for this datatype cuBLAS only
    // supports fixed values for alpha/beta.
    if (gemm->shape().element_type() == S32) {
      return OkStatus();
    }

    // A BLAS GEMM overwrites the bias matrix, so fusion is only possible if
    // the GEMM is the bias's only user. cublasLt matmul, in contrast, can
    // operate out-of-place.
    bool can_fuse_bias = (bias->user_count() == 1) || IsCublasLtMatmul(*gemm);

    auto config = gemm->backend_config<GemmBackendConfig>().ValueOrDie();

    // It is possible to fuse into a cublasLt matmul that already has a vector
    // bias, but no other epilogue will commute with the matrix bias add.
    bool supported_epilogue =
        ((config.epilogue() == GemmBackendConfig::DEFAULT) ||
         (config.epilogue() == GemmBackendConfig::BIAS));

    if ((config.beta() != 0) || !can_fuse_bias || (gemm->user_count() != 1) ||
        !supported_epilogue) {
      return OkStatus();
    }

    config.set_beta(1.0);

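    // The bias becomes the third custom-call operand, i.e. (lhs, rhs, bias).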
    std::vector<HloInstruction *> operands(gemm->operands().begin(),
                                           gemm->operands().end());
    operands.insert(operands.begin() + 2, MaybeConstantFoldBias(bias));

    std::unique_ptr<HloInstruction> fused_op =
        gemm->CloneWithNewOperands(instr->shape(), operands);

    TF_RETURN_IF_ERROR(fused_op->set_backend_config(config));
    if (IsCublasGemm(*fused_op)) {
      // Force bias input to alias with output, as GEMM operates in-place.
      xla::Cast<HloCustomCallInstruction>(fused_op.get())
          ->set_output_to_operand_aliasing({{{}, {2, {}}}});
    }
    TF_RETURN_IF_ERROR(SetName(instr->GetModule(), fused_op.get()));
    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(fused_op)));
    return OkStatus();
  }

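  // Rewrites matmul(a, b) + broadcast(bias), where bias covers exactly the
  // output column dimensions, into a cublasLt matmul with a BIAS epilogue,
  // provided the broadcast dimensions and layouts line up as checked below.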
  StatusOr<bool> FuseVectorBiasAdd(HloInstruction *instr,
                                   HloInstruction *broadcast_bias,
                                   HloInstruction *matmul) {
    TF_RET_CHECK(broadcast_bias->shape() == matmul->shape());

    auto config = matmul->backend_config<GemmBackendConfig>().ValueOrDie();

    // # output column dims == # non-contracting rhs operand dims.
    const DotDimensionNumbers &dot_dims = config.dot_dimension_numbers();
    size_t num_col_dims = matmul->operand(1)->shape().rank() -
                          dot_dims.rhs_batch_dimensions_size() -
                          dot_dims.rhs_contracting_dimensions_size();

    HloInstruction *bias = broadcast_bias->mutable_operand(0);
    if ((matmul->user_count() != 1) ||
        (config.epilogue() != GemmBackendConfig::DEFAULT) ||
        (bias->shape().rank() != num_col_dims)) {
      return false;
    }

    // We require the bias vector to have been broadcast in the most major
    // dimensions; i.e. its most minor physical dimensions align with the most
    // minor physical dimensions of the matmul output.
    absl::Span<const int64_t> broadcast_dims = broadcast_bias->dimensions();
    for (size_t i = 0; i < num_col_dims; ++i) {
      int64_t dim = matmul->shape().layout().minor_to_major(i);

      // Find the corresponding dimension from the bias vector.
      auto it = absl::c_find(broadcast_dims, dim);

      if (it == broadcast_dims.end()) {
        return false;
      }

      int64_t vector_dim = it - broadcast_dims.begin();
      if (bias->shape().layout().minor_to_major(i) != vector_dim) {
        return false;
      }
    }

    std::vector<HloInstruction *> operands(matmul->operands().begin(),
                                           matmul->operands().end());
    operands.push_back(bias);

    std::unique_ptr<HloInstruction> fused_op =
        matmul->CloneWithNewOperands(instr->shape(), operands);

    config.set_epilogue(GemmBackendConfig::BIAS);
    TF_RETURN_IF_ERROR(fused_op->set_backend_config(config));
    TF_RETURN_IF_ERROR(SetName(instr->GetModule(), fused_op.get()));
    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(fused_op)));
    return true;
  }
};

StatusOr<bool> RunOnComputation(HloComputation *computation) {
  GemmRewriterVisitor visitor;
  TF_RETURN_IF_ERROR(computation->Accept(&visitor));
  return visitor.changed();
}

}  // anonymous namespace

StatusOr<bool> GemmRewriter::Run(
    HloModule *module,
    const absl::flat_hash_set<absl::string_view> &execution_threads) {
  bool changed = false;
  for (HloComputation *computation :
       module->MakeNonfusionComputations(execution_threads)) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
    changed |= result;
  }
  return changed;
}

}  // namespace gpu
}  // namespace xla