/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gemm_rewriter.h"

#include <memory>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_cudnn.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/stream_executor/lib/statusor.h"

namespace xla {
namespace gpu {
namespace {

namespace m = match;

// Give this instruction a more useful name than "custom-call.42".
Status SetName(HloModule *module, HloInstruction *gemm) {
  if (IsCublasLtMatmul(*gemm)) {
    module->SetAndUniquifyInstrName(gemm, "cublas-lt-matmul");
    return OkStatus();
  }

  GemmBackendConfig config;
  TF_ASSIGN_OR_RETURN(config, gemm->backend_config<GemmBackendConfig>());
  const DotDimensionNumbers &dot_dims = config.dot_dimension_numbers();
  bool is_batch_dot = !dot_dims.lhs_batch_dimensions().empty() ||
                      !dot_dims.rhs_batch_dimensions().empty();

  module->SetAndUniquifyInstrName(
      gemm, is_batch_dot ? "cublas-batch-gemm" : "cublas-gemm");
  return OkStatus();
}

// If the bias is a sequence of ops that depend only on broadcasts of
// constants, materialize the bias if it's small.
//
// Normally the constant-folding pass would materialize the bias if it is
// calculated entirely from constants. But if the bias is a broadcast of a
// constant, constant-folding won't expand the broadcast, on the theory that
// folding broadcasts of constants causes us to consume more memory and can
// actually make things slower (because any op which reads the constant has
// to read more memory).
//
// OTOH in our case, we don't want to run an op that just broadcasts a
// constant so we can fuse it into this gemm. That would defeat the whole
// purpose of this fusion, which is to launch fewer kernels.  So if we can,
// we expand out this constant ourselves.
//
// TODO(b/192499646): Even better would be to use cublasLT to fuse the
// broadcasted bias, if it supports that fusion efficiently.
HloInstruction *MaybeConstantFoldBias(HloInstruction *bias) {
  // This limit was not chosen carefully.
  constexpr int kMaxMaterializeBiasBytes = 8 * 1024 * 1024;

  // Don't fold broadcasts of scalars -- algsimp will just collapse it again.
  auto is_nonscalar = [](const HloInstruction *instr) {
    return !ShapeUtil::IsEffectiveScalar(instr->shape());
  };

  // For now, only fold broadcast(constant) or
  // reshape/transpose/bitcast(broadcast(constant)). This lets us avoid the
  // complexity in the constant-folding pass about what is and isn't legal to
  // fold.
  auto broadcast_of_nonscalar =
      m::Broadcast(m::Constant().WithPredicate(is_nonscalar));

  if (ShapeUtil::ByteSizeOf(bias->shape()) <= kMaxMaterializeBiasBytes &&
      (Match(bias, broadcast_of_nonscalar) ||
       Match(bias, m::Reshape(broadcast_of_nonscalar)) ||
       Match(bias, m::Transpose(broadcast_of_nonscalar)) ||
       Match(bias, m::Bitcast(broadcast_of_nonscalar)))) {
    HloEvaluator evaluator(/*max_loop_iterations=*/0);
    Literal result;
    if (evaluator.TryEvaluate(
            bias, &result,
            /*recursively_evaluate_nonconstant_operands=*/true)) {
      return bias->parent()->AddInstruction(
          HloInstruction::CreateConstant(std::move(result)));
    }
  }

  return bias;
}
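
// For illustration, a hypothetical HLO bias pattern (not taken from a real
// module) that the helper above would materialize: a small, non-scalar
// constant hidden behind a broadcast.
//
//   %c    = f32[1024] constant({...})
//   %bias = f32[16,1024] broadcast(%c), dimensions={1}
//
// The broadcast result is 16*1024*4 bytes, well under the 8 MiB limit, so it
// is evaluated into a literal and replaced by a single constant instruction.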

// The rewriting proceeds in a bottom-up way:
//
// (kDot A B) is rewritten into a (kCustomCall:gemm A B)
//
// (kMultiply (kCustomCall:gemm A B) C) folds C (provided it's a constant) into
// the alpha parameter of the custom call.
//
// (kAdd (kCustomCall:gemm A B) C) is rewritten into (kCustomCall:gemm A B C),
// where the "beta" parameter is set to 1 (provided it was zero before,
// and provided C has no other users).
// We then guide the buffer assignment to alias the buffer of the custom call
// and C.
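//
// For illustration, a hypothetical before/after of the basic rewrite (shapes
// and operand names are made up):
//
//   before: %dot = f32[64,32] dot(%a, %b),
//               lhs_contracting_dims={1}, rhs_contracting_dims={0}
//   after:  %cublas-gemm = f32[64,32] custom-call(%a, %b),
//               custom_call_target="__cublas$gemm"
//
// with alpha=1 and beta=0 recorded in the GemmBackendConfig.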
class GemmRewriterVisitor : public DfsHloRewriteVisitor {
 public:
  Status HandleDot(HloInstruction *instr) override {
    if (IsMatrixMultiplication(*instr)) {
      CHECK(!instr->IsRank2Transpose());
      HloInstruction *lhs = instr->mutable_operand(0);
      HloInstruction *rhs = instr->mutable_operand(1);
      CHECK(!lhs->IsRank2Transpose());
      CHECK(!rhs->IsRank2Transpose());
      const Shape &output_shape = instr->shape();

      const char *const target =
          instr->GetModule()->config().debug_options().xla_gpu_enable_cublaslt()
              ? kCublasLtMatmulCallTarget
              : kGemmCallTarget;

      std::unique_ptr<HloInstruction> gemm_call =
          HloInstruction::CreateCustomCall(output_shape, {lhs, rhs}, target);
      GemmBackendConfig gemm_config;
      gemm_config.set_alpha_real(1.0);
      gemm_config.set_alpha_imag(0.0);
      gemm_config.set_beta(0.0);
      *gemm_config.mutable_dot_dimension_numbers() =
          instr->dot_dimension_numbers();
      *gemm_config.mutable_precision_config() = instr->precision_config();

      TF_RETURN_IF_ERROR(gemm_call->set_backend_config(gemm_config));
      TF_RETURN_IF_ERROR(SetName(instr->GetModule(), gemm_call.get()));
      TF_RETURN_IF_ERROR(
          ReplaceWithNewInstruction(instr, std::move(gemm_call)));
    }
    return OkStatus();
  }

  Status HandleMultiply(HloInstruction *instr) override {
    HloInstruction *alpha, *existing_gemm;
    if (Match(instr,
              m::MultiplyAnyOrder(
                  m::Op(&existing_gemm).WithCustomCallTarget(kGemmCallTarget),
                  m::Broadcast(m::ConstantScalar(&alpha))))) {
      TF_ASSIGN_OR_RETURN(auto config,
                          existing_gemm->backend_config<GemmBackendConfig>());

      // Do not fuse alpha into an S32 GEMM, as for this data type cuBLAS only
      // supports fixed values for alpha/beta.
      if (existing_gemm->shape().element_type() == S32) {
        return OkStatus();
      }

      if (config.beta() == 0.0 && existing_gemm->user_count() == 1) {
        complex128 prev_alpha = {config.alpha_real(), config.alpha_imag()};
        complex128 new_alpha =
            *alpha->literal().GetAsComplex128({}) * prev_alpha;
        config.set_alpha_real(new_alpha.real());
        config.set_alpha_imag(new_alpha.imag());
        TF_RETURN_IF_ERROR(existing_gemm->set_backend_config(config));
        TF_RETURN_IF_ERROR(ReplaceInstruction(instr, existing_gemm));
      }
    }
    return OkStatus();
  }

  Status HandleAdd(HloInstruction *instr) override {
    HloInstruction *bias, *existing_gemm;

    // First, try to match vector bias add, so we might elide the broadcast.
    if (Match(instr, m::AddAnyOrder(
                         m::Op(&existing_gemm)
                             .WithCustomCallTarget(kCublasLtMatmulCallTarget),
                         m::Broadcast(&bias, m::Op())))) {
      TF_ASSIGN_OR_RETURN(bool was_fused,
                          FuseVectorBiasAdd(instr, bias, existing_gemm));
      if (was_fused) return OkStatus();
    }

    // add(bitcast(gemm(a, b)), bias) ->
    //   bitcast(add(gemm(a, b), bitcast(bias))) ->
    //   bitcast(gemm(a, b, bitcast(bias))) (later down in this function).
    //
    // We see this idiom in models that contain batch-dots, where we cast
    // between a rank-2 shape for non-batch dots and a higher-rank shape for
    // batch-dots.
    //
    // The last stage of the transform may fail (because of any of the checks
    // in FuseMatrixBiasAdd), but if so that's okay -- we'll have done a
    // useless transformation, but it doesn't hurt anything.
    if (Match(instr, m::AddAnyOrder(
                         m::Bitcast(m::Op(&existing_gemm)
                                        .WithCustomCallTarget(kGemmCallTarget)
                                        .WithOneUser())
                             .WithOneUser(),
                         m::Op(&bias)))) {
      HloInstruction *new_bitcast =
          MakeBitcastHlo(bias, existing_gemm->shape(), &bias->metadata());
      TF_ASSIGN_OR_RETURN(HloInstruction * new_add,
                          MakeBinaryHlo(HloOpcode::kAdd, existing_gemm,
                                        new_bitcast, &bias->metadata()));
      TF_RETURN_IF_ERROR(
          ReplaceInstruction(instr, MakeBitcastHlo(new_add, instr->shape())));

      // Continue below, transforming new_add.
      instr = new_add;
    }

    if (Match(instr, m::AddAnyOrder(
                         m::Op(&existing_gemm)
                             .WithCustomCallTarget(
                                 {kGemmCallTarget, kCublasLtMatmulCallTarget}),
                         m::Op(&bias)))) {
      return FuseMatrixBiasAdd(instr, bias, existing_gemm);
    }

    return OkStatus();
  }

  // Fuse a matrix bias add into a BF16 GEMM when the addition is performed in
  // a wider type (via converts) and the result is converted back to BF16.
  Status HandleConvert(HloInstruction *instr) override {
    HloInstruction *bias, *existing_gemm;
    if (Match(
            instr,
            m::Convert(m::AddAnyOrder(
                           m::Convert(m::Op(&existing_gemm)
                                          .WithCustomCallTarget(kGemmCallTarget)
                                          .WithElementType(BF16)),
                           m::Convert(m::Op(&bias).WithElementType(BF16))))
                .WithElementType(BF16))) {
      return FuseMatrixBiasAdd(instr, bias, existing_gemm);
    }
    return OkStatus();
  }

  Status FuseMatrixBiasAdd(HloInstruction *instr, HloInstruction *bias,
                           HloInstruction *gemm) {
    TF_RET_CHECK(bias->shape() == gemm->shape());

    // Do not fuse bias into an S32 GEMM, as for this data type cuBLAS only
    // supports fixed values for alpha/beta.
    if (gemm->shape().element_type() == S32) {
      return OkStatus();
    }

    // A BLAS GEMM overwrites the bias matrix, so fusion is only possible if
    // the GEMM is the bias's only user. cublasLt matmul can operate
    // out-of-place.
    bool can_fuse_bias = (bias->user_count() == 1) || IsCublasLtMatmul(*gemm);

    auto config = gemm->backend_config<GemmBackendConfig>().ValueOrDie();

    // It is possible to fuse into a cublasLt matmul that already has a vector
    // bias, but no other epilogue will commute with the matrix bias add.
    bool supported_epilogue =
        ((config.epilogue() == GemmBackendConfig::DEFAULT) ||
         (config.epilogue() == GemmBackendConfig::BIAS));

    if ((config.beta() != 0) || !can_fuse_bias || (gemm->user_count() != 1) ||
        !supported_epilogue) {
      return OkStatus();
    }

    config.set_beta(1.0);

    std::vector<HloInstruction *> operands(gemm->operands().begin(),
                                           gemm->operands().end());
    operands.insert(operands.begin() + 2, MaybeConstantFoldBias(bias));

    std::unique_ptr<HloInstruction> fused_op =
        gemm->CloneWithNewOperands(instr->shape(), operands);

    TF_RETURN_IF_ERROR(fused_op->set_backend_config(config));
    if (IsCublasGemm(*fused_op)) {
      // Force bias input to alias with output, as GEMM operates in-place.
      xla::Cast<HloCustomCallInstruction>(fused_op.get())
          ->set_output_to_operand_aliasing({{{}, {2, {}}}});
    }
    TF_RETURN_IF_ERROR(SetName(instr->GetModule(), fused_op.get()));
    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(fused_op)));
    return OkStatus();
  }

  StatusOr<bool> FuseVectorBiasAdd(HloInstruction *instr,
                                   HloInstruction *broadcast_bias,
                                   HloInstruction *matmul) {
    TF_RET_CHECK(broadcast_bias->shape() == matmul->shape());

    auto config = matmul->backend_config<GemmBackendConfig>().ValueOrDie();

    // # output column dims == # non-contracting rhs operand dims.
    const DotDimensionNumbers &dot_dims = config.dot_dimension_numbers();
    size_t num_col_dims = matmul->operand(1)->shape().rank() -
                          dot_dims.rhs_batch_dimensions_size() -
                          dot_dims.rhs_contracting_dimensions_size();

    HloInstruction *bias = broadcast_bias->mutable_operand(0);
    if ((matmul->user_count() != 1) ||
        (config.epilogue() != GemmBackendConfig::DEFAULT) ||
        (bias->shape().rank() != num_col_dims)) {
      return false;
    }

    // We require the bias vector to have been broadcast in the most major
    // dimensions; i.e. its most minor physical dimensions align with the most
    // minor physical dimensions of the matmul output.
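    //
    // For example (hypothetical shapes): with a matmul output of
    // f32[16,1024]{1,0} and num_col_dims == 1, a bias f32[1024]{0} broadcast
    // with dimensions={1} qualifies, because output dimension 1 (the most
    // minor) maps back to bias dimension 0 (its most minor).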
    absl::Span<const int64_t> broadcast_dims = broadcast_bias->dimensions();
    for (size_t i = 0; i < num_col_dims; ++i) {
      int64_t dim = matmul->shape().layout().minor_to_major(i);

      // Find the corresponding dimension from the bias vector.
      auto it = absl::c_find(broadcast_dims, dim);

      if (it == broadcast_dims.end()) {
        return false;
      }

      int64_t vector_dim = it - broadcast_dims.begin();
      if (bias->shape().layout().minor_to_major(i) != vector_dim) {
        return false;
      }
    }

    std::vector<HloInstruction *> operands(matmul->operands().begin(),
                                           matmul->operands().end());
    operands.push_back(bias);

    std::unique_ptr<HloInstruction> fused_op =
        matmul->CloneWithNewOperands(instr->shape(), operands);

    config.set_epilogue(GemmBackendConfig::BIAS);
    TF_RETURN_IF_ERROR(fused_op->set_backend_config(config));
    TF_RETURN_IF_ERROR(SetName(instr->GetModule(), fused_op.get()));
    TF_RETURN_IF_ERROR(ReplaceWithNewInstruction(instr, std::move(fused_op)));
    return true;
  }
};

StatusOr<bool> RunOnComputation(HloComputation *computation) {
  GemmRewriterVisitor visitor;
  TF_RETURN_IF_ERROR(computation->Accept(&visitor));
  return visitor.changed();
}

}  // anonymous namespace

StatusOr<bool> GemmRewriter::Run(
    HloModule *module,
    const absl::flat_hash_set<absl::string_view> &execution_threads) {
  bool changed = false;
  for (HloComputation *computation :
       module->MakeNonfusionComputations(execution_threads)) {
    TF_ASSIGN_OR_RETURN(bool result, RunOnComputation(computation));
    changed |= result;
  }
  return changed;
}
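
// Typical usage (an illustrative sketch, not part of this file): the pass is
// appended to an HloPassPipeline during GPU compilation, roughly as
//
//   HloPassPipeline pipeline("gemm-rewriting");
//   pipeline.AddPass<GemmRewriter>();
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(hlo_module));
//
// The pipeline name and surrounding setup above are hypothetical.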

}  // namespace gpu
}  // namespace xla