/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"

#include <algorithm>
#include <memory>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/env_var.h"

namespace xla {
namespace gpu {

namespace {

PrimitiveType GetUniqueOutputTypeOfFusible(const HloInstruction& fusible) {
  auto outputs = GetOutputsOfFusible(fusible);
  CHECK(!outputs.empty());
  PrimitiveType first_output_type = outputs[0]->shape().element_type();
  for (size_t i = 1; i < outputs.size(); ++i) {
    PrimitiveType cur_output_type = outputs[i]->shape().element_type();
    CHECK(first_output_type == cur_output_type)
        << "Output types are expected to be unique, but saw "
        << PrimitiveType_Name(first_output_type) << " and "
        << PrimitiveType_Name(cur_output_type);
  }

  return first_output_type;
}

class HorizontalLoopFusionImpl {
 public:
  explicit HorizontalLoopFusionImpl(HloComputation* computation)
      : computation_(computation) {}

  ~HorizontalLoopFusionImpl() {}

  StatusOr<bool> Run();

 private:
  Status Fuse(absl::Span<HloInstruction*> fused_fusion_instrs);

  // Horizontally fuses `fused_fusion_instrs`. It is required that each of
  // `fused_fusion_instrs` is a kLoop fusion. We also require them to have the
  // same number of outputs, so that each output is fused/concatenated with the
  // same number of outputs from the other fused fusion instrs, and all the
  // fused outputs still have the same shapes for kernel generation.
  //
  // Returns the fused computation in `uniq_computation` and the operands that
  // are used by `uniq_computation`.
  Status CreateFusedComputation(
      absl::Span<HloInstruction*> fused_fusion_instrs,
      std::unique_ptr<HloComputation>* uniq_computation,
      std::vector<HloInstruction*>* bound_operands);

  // FusionCandidates collects profitable candidates for a given consumer
  // instruction. GetNextSpanOfFusions() can then be iteratively invoked to
  // acquire the next set of fusion candidates based on some heuristics.
  class FusionCandidates {
   public:
    explicit FusionCandidates(HloInstruction* consumer)
        : fusible_instrs_(), pos_(0) {
      Initialize(consumer);
    }

    // Gets a span of fusions to be fused.
    absl::Span<HloInstruction*> GetNextSpanOfFusions();

   private:
    void Initialize(HloInstruction*);

    std::vector<HloInstruction*> fusible_instrs_;
    // `pos_` points to the start position of the next span.
    size_t pos_;
  };

  HloComputation* computation_;
};  // HorizontalLoopFusionImpl
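
// A rough sketch of the rewrite performed below (shapes and names are
// illustrative, not taken from any real module): given two fusible kLoop
// fusions producing f32[100] and f32[200], CreateFusedComputation() reshapes
// each (cloned) output to rank 1, concatenates them into f32[300], and slices
// the concat back into f32[100] and f32[200]. Fuse() then wraps this
// computation in a single kInput fusion and replaces the original fusions
// with bitcasts of get-tuple-elements of the new fusion, so only one kernel
// launch remains for the fused group.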

bool IsFusibleCandidate(const HloInstruction& instr) {
  // Element-wise instructions require no further checks.
  if (instr.IsElementwise() && instr.operand_count() > 0) {
    return true;
  }

  // Exclude fusions other than kLoop.
  if (!instr.IsLoopFusion()) {
    return false;
  }

  // We cannot support fusions with multiple output types, because the
  // concatenate (inserted for horizontal fusion) requires all of its operands
  // to have the same type.
  auto outputs = GetOutputsOfFusible(instr);
  CHECK(!outputs.empty());
  const HloInstruction* first_output = outputs[0];
  for (size_t i = 1; i < outputs.size(); ++i) {
    if (first_output->shape().element_type() !=
        outputs[i]->shape().element_type()) {
      return false;
    }
  }

  return true;
}

// Returns whether `instr` is a profitable candidate to be horizontally fused.
// Since the primary benefit of horizontal fusion comes from reducing the
// kernel launch overhead, we want to exclude instructions whose launch
// overhead is insignificant relative to their computation time. In other
// words, we exclude instructions whose computation latencies are longer than
// their launch latencies. We estimate the computation latency of a given
// instruction by its shapes and the instruction count in its fused
// computation. We roughly observe that if a fusion instruction has shapes
// smaller than `kShapeThreshold` and fewer instructions than
// `kInstrCountThreshold`, it is launch-latency-bound and profitable to fuse
// horizontally.
bool IsProfitableFusionCandidate(const HloInstruction& instr) {
  constexpr int64_t kShapeThreshold = 128 * 2048;
  constexpr int64_t kInstrCountThreshold = 30;
  const HloInstruction* root = (instr.opcode() == HloOpcode::kFusion)
                                   ? instr.fused_expression_root()
                                   : &instr;

  // Shapes that are too large are unlikely to be profitable.
  if (root->opcode() == HloOpcode::kTuple) {
    // Since all output shapes are the same, use the first shape as the
    // representative.
    Shape shape = root->operand(0)->shape();
    if (ShapeUtil::ElementsIn(shape) > kShapeThreshold) {
      return false;
    }
  } else {
    Shape shape = root->shape();
    if (ShapeUtil::ElementsIn(shape) > kShapeThreshold) {
      return false;
    }
  }

  // Having too many instructions is unlikely to be profitable.
  if (instr.opcode() == HloOpcode::kFusion &&
      instr.fused_instruction_count() > kInstrCountThreshold) {
    return false;
  }

  return true;
}
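
// To make the threshold concrete (illustrative numbers, not from any real
// module): kShapeThreshold = 128 * 2048 = 262144 elements, so an
// f32[256,1024] output (exactly 262144 elements) still qualifies, while an
// f32[512,1024] output (524288 elements) is rejected as likely compute-bound.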

// Returns whether `instr` has only row-major layouts.
// The horizontal fusion excludes computations with non-row-major layouts,
// because fusing computations with different layouts can result in uncoalesced
// memory accesses and cause severe performance overhead.
bool HasOnlyRowMajorLayout(const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kFusion) {
    return LayoutUtil::IsMonotonicWithDim0Major(instr.shape().layout());
  }

  auto fused_instrs = instr.fused_instructions_computation()->instructions();
  for (HloInstruction* i : fused_instrs) {
    if (i->shape().layout().format() != DENSE) {
      continue;
    }
    if (!LayoutUtil::IsMonotonicWithDim0Major(i->shape().layout())) {
      return false;
    }
  }
  return true;
}
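
// For example (shapes invented for illustration): an instruction with layout
// f32[2,3]{1,0} (row-major, minor-to-major = {1,0}) passes this check, while
// f32[2,3]{0,1} (column-major) causes the candidate to be rejected.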

// Returns whether any operand of `instr` is a parameter instruction that
// is shared with any instruction in `fusion_instrs`.
bool AnyOpndIsParamSharedAmongFusions(
    const HloInstruction* instr,
    const absl::flat_hash_set<HloInstruction*>& fusion_instrs) {
  return absl::c_any_of(instr->operands(), [&](const HloInstruction* opnd) {
    return opnd->opcode() == HloOpcode::kParameter &&
           absl::c_any_of(opnd->users(), [&](const HloInstruction* user) {
             return user != instr && fusion_instrs.contains(user);
           });
  });
}
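
// Illustrative scenario (instruction names invented): if fusion.1 and fusion.2
// both read parameter p0 and both feed the same consumer, this check fires for
// either of them, so neither is collected as a candidate; see the comment in
// Initialize() below on input/output aliasing for the rationale.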

void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
    HloInstruction* consumer) {
  // First, find all potential target candidates. We will filter out
  // unsupported/unprofitable cases below.
  absl::flat_hash_set<HloInstruction*> fusible_candidates;
  std::vector<HloInstruction*> ordered_fusible_candidates;
  for (HloInstruction* opnd : consumer->operands()) {
    HloInstruction* predecessor = opnd->LatestNonGteAncestor();
    // We currently support kLoop fusions and element-wise HLOs; the support
    // list may be extended if the need arises.
    if (IsFusibleCandidate(*predecessor)) {
      if (fusible_candidates.insert(predecessor).second) {
        // Add the unseen candidate to the ordered list.
        ordered_fusible_candidates.push_back(predecessor);
      }
    }
  }

  for (HloInstruction* instr : ordered_fusible_candidates) {
    if (!IsConsumerTheOnlyNonRootUser(*instr, *consumer)) {
      VLOG(2) << "Reject maybe-illegal instr " << instr->ToString()
              << "; including it may create cycles in HLO.";
      continue;
    } else if (!IsProfitableFusionCandidate(*instr)) {
      VLOG(2) << "Reject likely-unprofitable fusion instr "
              << instr->ToString();
      continue;
    } else if (!HasOnlyRowMajorLayout(*instr)) {
      VLOG(2) << "Reject non-row-major fusion instr " << instr->ToString();
      continue;
    } else if (AnyOpndIsParamSharedAmongFusions(instr, fusible_candidates)) {
      // Don't fuse fusions whose operands are parameter instructions shared
      // with other fusions, because the concat insertion prevents us from
      // i/o-aliasing the produced horizontal fusion.
      VLOG(2) << "Reject the fusion instr because it shares parameters with"
              << " other fusion candidates, instr: " << instr->ToString();
      continue;
    } else {
      VLOG(2) << "Found a fusion candidate " << instr->ToString();
      // It will be encapsulated into a fusion computation (if it is not one
      // already) for a unified representation in later processing.
      fusible_instrs_.push_back(instr);
    }
  }

  // Sort `fusible_instrs_` according to output types, the number of outputs,
  // and instruction counts, because we only fuse instructions with the same
  // number/type of outputs and whose computations have the same instruction
  // count.
  std::sort(
      fusible_instrs_.begin(), fusible_instrs_.end(),
      [&](const HloInstruction* a, const HloInstruction* b) {
        if (GetUniqueOutputTypeOfFusible(*a) !=
            GetUniqueOutputTypeOfFusible(*b)) {
          return GetUniqueOutputTypeOfFusible(*a) <
                 GetUniqueOutputTypeOfFusible(*b);
        } else if (GetOutputSizeOfFusible(*a) != GetOutputSizeOfFusible(*b)) {
          return GetOutputSizeOfFusible(*a) < GetOutputSizeOfFusible(*b);
        } else {
          return GetInstrCountOfFusible(*a) < GetInstrCountOfFusible(*b);
        }
      });
}
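
// A sketch of what the sorting buys us (types and counts invented): if the
// candidates end up ordered as {F16/1-output/5-instrs, F16/1-output/5-instrs,
// F32/2-outputs/8-instrs}, then GetNextSpanOfFusions() below returns the two
// compatible F16 fusions as one batch, and the lone F32 fusion as a separate
// single-element span that Run() simply skips.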

// Gets the next span of fusion instructions to be fused.
absl::Span<HloInstruction*>
HorizontalLoopFusionImpl::FusionCandidates::GetNextSpanOfFusions() {
  if (pos_ >= fusible_instrs_.size()) {
    return absl::Span<HloInstruction*>();
  }

  // Fusing too many computations at a time may not be easily profitable and
  // may increase compile time due to large kernels. Set a limit on the batch
  // size.
  constexpr int64_t kMaxFusionBatchSize = 32;
  // CUDA has a parameter size limit of ~4k bytes.
  constexpr int64_t kMaxCudaParamSize = 4000;
  size_t accum_io_size = 0;
  auto reach_max_fusion_batch_size = [&](size_t left, size_t right) -> bool {
    if (right - left >= kMaxFusionBatchSize) {
      return true;
    }

    accum_io_size += fusible_instrs_.at(right)->operand_count() +
                     GetOutputSizeOfFusible(*fusible_instrs_.at(right));

    // Each operand/output is passed to the kernel as a pointer, assumed to
    // take 8 bytes of the parameter buffer.
    if (accum_io_size * 8 >= kMaxCudaParamSize) {
      return true;
    }

    return false;
  };

  size_t left = pos_;
  size_t right = pos_ + 1;
  size_t first_output_size = GetOutputSizeOfFusible(*fusible_instrs_[left]);
  PrimitiveType first_output_type =
      GetUniqueOutputTypeOfFusible(*fusible_instrs_[left]);
  for (; right < fusible_instrs_.size(); ++right) {
    PrimitiveType cur_output_type =
        GetUniqueOutputTypeOfFusible(*fusible_instrs_[right]);
    if (first_output_type != cur_output_type) {
      // Cannot fuse computations that have different output types.
      break;
    } else if (first_output_size !=
               GetOutputSizeOfFusible(*fusible_instrs_[right])) {
      // Cannot fuse computations that have different numbers of outputs.
      break;
    } else if (GetInstrCountOfFusible(*fusible_instrs_[left]) !=
               GetInstrCountOfFusible(*fusible_instrs_[right])) {
      // Do not fuse computations with different instruction counts, as doing
      // so may introduce control divergence. This is a very simple heuristic
      // to avoid fusing computations with too much discrepancy; we may
      // improve it when the need arises.
      break;
    } else if (reach_max_fusion_batch_size(left, right)) {
      // Hit the max fusion batch size.
      break;
    }
  }

  pos_ = right;
  return absl::MakeSpan(fusible_instrs_).subspan(left, right - left);
}
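
// Rough arithmetic for the caps above (illustrative): the 4000-byte parameter
// limit corresponds to about 500 operand/output pointers at 8 bytes each, so
// unless each candidate contributes roughly 16 or more operands/outputs, the
// batch-size cap of 32 is reached before the parameter-size cap.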

Status HorizontalLoopFusionImpl::CreateFusedComputation(
    absl::Span<HloInstruction*> fused_fusion_instrs,
    std::unique_ptr<HloComputation>* uniq_computation,
    std::vector<HloInstruction*>* bound_operands) {
  // First, build a computation with only params.
  HloComputation::Builder b("horizontally_fused_computation");
  size_t fused_comp_param_id = 0;
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    auto old_params = fused_fusion_instrs[i]->fused_parameters();
    for (size_t j = 0; j < old_params.size(); ++j) {
      HloInstruction* bound_opnd = fused_fusion_instrs[i]->mutable_operand(j);
      // Named in the form of param_i_j.
      b.AddInstruction(HloInstruction::CreateParameter(
          fused_comp_param_id++, bound_opnd->shape(),
          absl::StrCat("param_", i, "_", j)));
      bound_operands->push_back(bound_opnd);
    }
  }
  // Always create a dummy tuple instruction to serve as the root of the
  // computation, as HloComputation requires a root instruction to exist. The
  // real root instruction will replace it below.
  HloInstruction* dummy_root = b.AddInstruction(
      HloInstruction::CreateTuple(std::vector<HloInstruction*>{}));
  *uniq_computation = b.Build(dummy_root);
  HloComputation* comp = uniq_computation->get();

  // Prepare clone_map, which maps old operands to new operands.
  absl::flat_hash_map<const HloInstruction*, HloInstruction*> clone_map;
  size_t new_param_id = 0;
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    auto old_params = fused_fusion_instrs[i]->fused_parameters();
    for (size_t j = 0; j < old_params.size(); ++j) {
      HloInstruction* old_param = old_params[j];
      HloInstruction* new_param = comp->parameter_instruction(new_param_id++);
      clone_map.insert({old_param, new_param});
    }
  }

  // Clone every fused computation.
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    auto def_to_use_order = fused_fusion_instrs[i]
                                ->fused_instructions_computation()
                                ->MakeInstructionPostOrder();
    for (HloInstruction* old_instr : def_to_use_order) {
      if (old_instr->opcode() == HloOpcode::kParameter) {
        // Parameters have already been created above.
        continue;
      }
      std::vector<HloInstruction*> new_opnds;
      for (HloInstruction* old_opnd : old_instr->operands()) {
        CHECK(clone_map.find(old_opnd) != clone_map.end());
        new_opnds.push_back(clone_map[old_opnd]);
      }
      HloInstruction* new_instr = comp->AddInstruction(
          old_instr->CloneWithNewOperands(old_instr->shape(), new_opnds));
      clone_map.insert({old_instr, new_instr});
    }
  }

  std::vector<HloInstruction*> concated_outputs;
  // Since we require each fusion to have the same number of outputs, we can
  // simply use the first fusion as the representative for output size.
  size_t fused_instr_output_size =
      GetOutputSizeOfFusible(*fused_fusion_instrs[0]);
  for (size_t i = 0; i < fused_instr_output_size; ++i) {
    std::vector<HloInstruction*> reshapes(fused_fusion_instrs.size());
    for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
      const HloInstruction* old_output =
          GetOutputsOfFusible(*fused_fusion_instrs[j])[i];
      HloInstruction* new_output = clone_map[old_output];
      // Reshape each output to rank 1 so that outputs of different shapes can
      // be concatenated along dimension 0.
      TF_ASSIGN_OR_RETURN(
          reshapes[j],
          MakeReshapeHlo(ShapeUtil::MakeShapeWithLayout(
                             new_output->shape().element_type(),
                             {ShapeUtil::ElementsIn(new_output->shape())},
                             /*minor_to_major=*/std::vector<int64_t>(1, 0)),
                         new_output));
    }
    TF_ASSIGN_OR_RETURN(HloInstruction * concated_output,
                        MakeConcatHlo(reshapes, 0));
    concated_outputs.push_back(concated_output);
  }

  // Make slices of outputs.
  std::vector<HloInstruction*> output_slices(concated_outputs.size() *
                                             fused_fusion_instrs.size());
  for (size_t i = 0; i < concated_outputs.size(); ++i) {
    HloInstruction* concated_output = concated_outputs[i];
    int64_t slice_start = 0;
    // Create a slice per fused computation.
    for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
      const HloInstruction* old_output =
          GetOutputsOfFusible(*fused_fusion_instrs[j])[i];
      Shape shape = old_output->shape();
      int64_t slice_limit = slice_start + ShapeUtil::ElementsIn(shape);
      TF_ASSIGN_OR_RETURN(
          output_slices[concated_outputs.size() * j + i],
          MakeSliceHlo(concated_output, {slice_start}, {slice_limit},
                       /*strides=*/{1}));
      slice_start = slice_limit;
    }
  }

  // Make a tuple of output_slices.
  HloInstruction* tuple =
      comp->AddInstruction(HloInstruction::CreateTuple(output_slices));
  comp->set_root_instruction(tuple, /*accept_different_shape=*/true);
  TF_RETURN_IF_ERROR(comp->RemoveInstruction(dummy_root));

  return Status::OK();
}
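
// A concrete (invented) illustration of the computation built above: fusing
// two single-output kLoop fusions with outputs f32[2,3] and f32[10] yields
//   reshape.0  = f32[6]  reshape(clone_of_output_0)
//   reshape.1  = f32[10] reshape(clone_of_output_1)
//   concat     = f32[16] concatenate(reshape.0, reshape.1), dimensions={0}
//   slice.0    = f32[6]  slice(concat), slice={[0:6]}
//   slice.1    = f32[10] slice(concat), slice={[6:16]}
//   ROOT tuple = (f32[6], f32[10]) tuple(slice.0, slice.1)
// with the slices later bitcast back to f32[2,3] and f32[10] outside the
// fusion (see Fuse() below).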

Status HorizontalLoopFusionImpl::Fuse(
    absl::Span<HloInstruction*> fused_fusion_instrs) {
  // Fuse fused_fusion_instrs and replace them with the new fused computation.
  std::unique_ptr<HloComputation> uniq_computation;
  std::vector<HloInstruction*> bound_operands;
  TF_RETURN_IF_ERROR(CreateFusedComputation(
      fused_fusion_instrs, &uniq_computation, &bound_operands));
  HloComputation* fused_comp = computation_->parent()->AddEmbeddedComputation(
      std::move(uniq_computation));
  HloInstruction* hori_fusion_instr =
      computation_->AddInstruction(HloInstruction::CreateFusion(
          fused_comp->root_instruction()->shape(),
          HloInstruction::FusionKind::kInput, bound_operands, fused_comp));
  fused_comp->SetFusionInstruction(hori_fusion_instr);

  // Insert bitcasts and replace corresponding users. Note that we do not
  // insert the bitcasts in the fused computation, as they do not fit into the
  // slice input fusion pattern. However, inserting bitcasts outside the fused
  // computation creates no performance cost.
  size_t total_output_id = 0;
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    std::vector<HloInstruction*> bitcasts;
    HloInstruction* fused_instr = fused_fusion_instrs[i];
    size_t num_outputs = GetOutputSizeOfFusible(*fused_instr);
    for (size_t j = 0; j < num_outputs; ++j) {
      const HloInstruction* output = GetOutputsOfFusible(*fused_instr)[j];
      TF_ASSIGN_OR_RETURN(
          HloInstruction * gep,
          MakeGetTupleElementHlo(hori_fusion_instr, total_output_id++));
      bitcasts.push_back(computation_->AddInstruction(
          HloInstruction::CreateBitcast(output->shape(), gep)));
    }
    HloInstruction* bitcast_or_tuple =
        (bitcasts.size() == 1) ? bitcasts.at(0)
                               : computation_->AddInstruction(
                                     HloInstruction::CreateTuple(bitcasts));
    TF_RETURN_IF_ERROR(
        computation_->ReplaceInstruction(fused_instr, bitcast_or_tuple));
  }

  return Status::OK();
}

StatusOr<bool> HorizontalLoopFusionImpl::Run() {
  bool changed = false;
  XLA_VLOG_LINES(3, computation_->ToString());

  // Traverse from use to def. Bitcasts are placed after h-fusions to resolve
  // shape mismatches, but those bitcasts could prevent future h-fusions from
  // happening. So a bottom-up, use-to-def order is more favorable. It also
  // helps save the compiler iterations needed to reach the fixed point.
  std::vector<HloInstruction*> use_to_def_order =
      computation_->MakeInstructionPostOrder();
  absl::c_reverse(use_to_def_order);
  for (size_t i = 0; i < use_to_def_order.size(); ++i) {
    HloInstruction* consumer = use_to_def_order[i];
    HorizontalLoopFusionImpl::FusionCandidates fusion_candidates(consumer);
    while (true) {
      auto fusibles = fusion_candidates.GetNextSpanOfFusions();
      if (fusibles.empty()) {
        break;
      } else if (fusibles.size() == 1) {
        // Skip; there is just one fused_instr.
        continue;
      }

      changed = true;
      // Convert the fusibles into fusion instructions to simplify the
      // implementation of `Fuse()`.
      std::vector<HloInstruction*> fusion_instrs;
      for (HloInstruction* instr : fusibles) {
        if (instr->opcode() == HloOpcode::kFusion) {
          fusion_instrs.push_back(instr);
        } else {
          TF_ASSIGN_OR_RETURN(
              HloInstruction * fusion_instr,
              MakeFusionInstruction(instr, HloInstruction::FusionKind::kLoop));
          fusion_instrs.push_back(fusion_instr);
        }
      }
      TF_RETURN_IF_ERROR(Fuse(absl::MakeSpan(fusion_instrs)));
    }
  }

  return changed;
}

}  // namespace

StatusOr<bool> GpuHorizontalLoopFusion::RunOnComputation(
    HloComputation* computation) {
  HorizontalLoopFusionImpl horizontal_fusion_impl(computation);
  return horizontal_fusion_impl.Run();
}

StatusOr<bool> GpuHorizontalLoopFusion::Run(HloModule* module) {
  bool changed = false;
  VLOG(2) << "Run horizontal fusion.";

  // Running on the entry computation is enough.
  TF_ASSIGN_OR_RETURN(changed, RunOnComputation(module->entry_computation()));

  return changed;
}

}  // namespace gpu
}  // namespace xla