/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"

#include <algorithm>

#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/env_var.h"

namespace xla {
namespace gpu {

namespace {
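
// A rough sketch of what this pass does; the instruction names (Mul, Add) and
// the arity below are invented for exposition, not taken from a real module:
//
//   i0 i1    i2 i3                i0 i1    i2 i3
//    \ /      \ /                  \ /      \ /
//    Mul      Add       ===>       Mul      Add
//     |        |                    |        |
//     v        v                Reshape0 Reshape1   (flatten to 1-D)
//  consumer consumer                 \      /
//                                  Concatenate
//                                    /     \
//                                Slice0   Slice1    (split back per fusion)
//                                  |         |
//                               Bitcast   Bitcast   (restore output shapes)
//                                  |         |
//                                  v         v
//                               consumer  consumer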

// Returns the instructions producing the outputs of fusion `instr`: the
// operands of the root tuple for a multi-output fusion, or the root itself
// otherwise.
absl::InlinedVector<HloInstruction*, 2> GetOutputsOfFusion(
    const HloInstruction& instr) {
  CHECK(instr.opcode() == HloOpcode::kFusion);
  HloInstruction* root = instr.fused_expression_root();
  if (root->opcode() != HloOpcode::kTuple) {
    return {root};
  } else {
    return root->operands();
  }
}

// Returns the number of outputs of the fused computation.
size_t GetOutputSizeOfFusion(const HloInstruction& instr) {
  CHECK(instr.opcode() == HloOpcode::kFusion);
  const HloInstruction* root = instr.fused_expression_root();
  if (root->opcode() != HloOpcode::kTuple) {
    return 1;
  } else {
    return ShapeUtil::TupleElementCount(root->shape());
  }
}

// Returns the output element type of the fusion `instr`, checking that all
// outputs share it.
PrimitiveType GetUniqueOutputTypeOfFusion(const HloInstruction& instr) {
  auto outputs = GetOutputsOfFusion(instr);
  CHECK(!outputs.empty());
  PrimitiveType first_output_type = outputs[0]->shape().element_type();
  for (size_t i = 1; i < outputs.size(); ++i) {
    PrimitiveType cur_output_type = outputs[i]->shape().element_type();
    CHECK(first_output_type == cur_output_type)
        << "Output types are expected to be unique, but got "
        << PrimitiveType_Name(first_output_type) << " and "
        << PrimitiveType_Name(cur_output_type);
  }

  return first_output_type;
}

class HorizontalLoopFusionImpl {
 public:
  explicit HorizontalLoopFusionImpl(HloComputation* computation)
      : computation_(computation) {}

  ~HorizontalLoopFusionImpl() {}

  StatusOr<bool> Run();

 private:
  Status Fuse(absl::Span<HloInstruction*> fused_fusion_instrs);

  // Horizontally fuses `fused_fusion_instrs`. Each of `fused_fusion_instrs`
  // must be a kLoop fusion, and all of them must have the same number of
  // outputs, so that each output is fused/concatenated with the same number
  // of outputs from the other fused fusion instrs and the fused outputs keep
  // matching shapes for kernel generation.
  //
  // Returns the fused computation in `uniq_computation` and the operands that
  // are used by `uniq_computation`.
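  //
  // As a hypothetical illustration (the shapes are invented, not drawn from a
  // real module): fusing two single-output kLoop fusions with roots of shape
  // f32[3] and f32[2] produces one computation whose outputs are reshaped to
  // 1-D, concatenated into an f32[5], and sliced back into a tuple of f32[3]
  // and f32[2].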
  Status CreateFusedComputation(
      absl::Span<HloInstruction*> fused_fusion_instrs,
      std::unique_ptr<HloComputation>* uniq_computation,
      std::vector<HloInstruction*>* bound_operands);

  // FusionCandidates collects profitable candidates for a given consumer
  // instruction. GetNextSpanOfFusions() can then be iteratively invoked to
  // acquire the next set of fusion candidates based on some heuristics.
  class FusionCandidates {
   public:
    explicit FusionCandidates(HloInstruction* consumer)
        : fusion_instrs_(), pos_(0) {
      Initialize(consumer);
    }

    // Gets a span of fusions to be fused.
    absl::Span<HloInstruction*> GetNextSpanOfFusions();

   private:
    void Initialize(HloInstruction*);

    std::vector<HloInstruction*> fusion_instrs_;
    // `pos_` points to the start position of the next span.
    size_t pos_;
  };

  HloComputation* computation_;
};  // HorizontalLoopFusionImpl

bool IsFusionSupported(const HloInstruction& instr) {
  // Support only kLoop fusion now.
  if (!instr.IsLoopFusion()) {
    return false;
  }

  // Fusions with multiple output types are not supported, because the
  // concatenate (inserted for horizontal fusion) requires the same element
  // type for all of its operands.
  auto outputs = GetOutputsOfFusion(instr);
  CHECK(!outputs.empty());
  const HloInstruction* first_output = outputs[0];
  for (size_t i = 1; i < outputs.size(); ++i) {
    if (first_output->shape().element_type() !=
        outputs[i]->shape().element_type()) {
      return false;
    }
  }

  return true;
}

// Returns whether `instr` is a profitable candidate to be horizontally fused.
// Since the primary benefit of horizontal fusion comes from reducing the
// kernel launch overhead, we want to exclude instructions with insignificant
// kernel launch overhead. In other words, we exclude instructions whose
// computation latency exceeds their launch latency. We estimate the
// computation latency of a given instruction from its shapes and the
// instruction count in its fused computation. We roughly observe that a
// fusion instruction with shapes smaller than `kShapeThreshold` and fewer
// instructions than `kInstrCountThreshold` is launch-latency-bound and
// profitable to fuse horizontally.
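//
// For instance, under these thresholds a fusion whose root shape is
// f32[128,1024] (131072 elements, below 128 * 2048 = 262144) and whose fused
// computation has, say, 10 instructions would be kept as a candidate; the
// numbers here are illustrative only.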
bool IsProfitableFusionCandidate(const HloInstruction& instr) {
  CHECK(instr.opcode() == HloOpcode::kFusion);
  constexpr int64 kShapeThreshold = 128 * 2048;
  constexpr int64 kInstrCountThreshold = 30;
  auto root = instr.fused_expression_root();

  // Too-large shapes are not easily profitable.
  if (root->opcode() == HloOpcode::kTuple) {
    // Since all output shapes are the same, use the first shape as the
    // representative.
    auto shape = root->operand(0)->shape();
    if (ShapeUtil::ElementsIn(shape) > kShapeThreshold) {
      return false;
    }
  } else {
    auto shape = root->shape();
    if (ShapeUtil::ElementsIn(shape) > kShapeThreshold) {
      return false;
    }
  }

  // Having too many instructions is not easily profitable.
  if (instr.fused_instruction_count() > kInstrCountThreshold) {
    return false;
  }

  // We can emit DUS in-place; horizontally fusing it makes the emitter no
  // longer recognize that it can be done in-place, which creates much slower
  // code. This restriction could be lifted if buffer assignment recognized
  // that the DUS can be done in-place even inside of a horizontal fusion.
  if (root->opcode() == HloOpcode::kDynamicUpdateSlice) {
    return false;
  }

  return true;
}

// Returns whether `fusion_instr` has only row-major layouts.
// Horizontal fusion excludes computations with non-row-major layouts, because
// fusing computations with different layouts can result in uncoalesced memory
// accesses, which causes significant performance overhead.
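//
// For example, a rank-2 shape with layout minor_to_major = {1, 0} (row-major)
// passes this check, whereas minor_to_major = {0, 1} (column-major) does not.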
bool HasOnlyRowMajorLayout(const HloInstruction& fusion_instr) {
  CHECK(fusion_instr.opcode() == HloOpcode::kFusion);
  auto instrs = fusion_instr.fused_instructions_computation()->instructions();
  for (auto instr : instrs) {
    if (instr->shape().layout().format() != DENSE) {
      continue;
    }
    if (!LayoutUtil::IsMonotonicWithDim0Major(instr->shape().layout())) {
      return false;
    }
  }
  return true;
}

void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
    HloInstruction* consumer) {
  // First, find out all fusion instructions. We will filter out
  // unsupported/non-profitable cases below.
  absl::flat_hash_set<HloInstruction*> fusion_instrs;
  for (auto opnd : consumer->operands()) {
    auto predecessor = opnd->LatestNonGteAncestor();
    if (predecessor->opcode() == HloOpcode::kFusion) {
      fusion_instrs.insert(predecessor);
    }
  }

  for (auto instr : fusion_instrs) {
    if (!IsFusionSupported(*instr)) {
      VLOG(2) << "Reject unsupported fusion instr " << instr->ToString();
      continue;
    } else if (!IsConsumerTheOnlyNonRootUser(*instr, *consumer)) {
      VLOG(2) << "Reject maybe illegal instr " << instr->ToString()
              << "; including it may create cycles in HLO.";
      continue;
    } else if (!IsProfitableFusionCandidate(*instr)) {
      VLOG(2) << "Reject may-not-be-profitable fusion instr "
              << instr->ToString();
      continue;
    } else if (!HasOnlyRowMajorLayout(*instr)) {
      VLOG(2) << "Reject non-row-major fusion instr " << instr->ToString();
      continue;
    } else {
      VLOG(2) << "Found a fusion candidate " << instr->ToString();
      fusion_instrs_.push_back(instr);
    }
  }

  // Sort `fusion_instrs_` by output type, number of outputs, and instruction
  // count, because we only fuse instructions with the same number/type of
  // outputs and whose computations have the same instruction count.
  std::sort(
      fusion_instrs_.begin(), fusion_instrs_.end(),
      [&](const HloInstruction* a, const HloInstruction* b) {
        if (GetUniqueOutputTypeOfFusion(*a) !=
            GetUniqueOutputTypeOfFusion(*b)) {
          return GetUniqueOutputTypeOfFusion(*a) <
                 GetUniqueOutputTypeOfFusion(*b);
        } else if (GetOutputSizeOfFusion(*a) != GetOutputSizeOfFusion(*b)) {
          return GetOutputSizeOfFusion(*a) < GetOutputSizeOfFusion(*b);
        } else {
          return a->fused_instruction_count() < b->fused_instruction_count();
        }
      });
}

// Gets the next span of fusion instructions to be fused.
absl::Span<HloInstruction*>
HorizontalLoopFusionImpl::FusionCandidates::GetNextSpanOfFusions() {
  if (pos_ >= fusion_instrs_.size()) {
    return absl::Span<HloInstruction*>();
  }

  // Fusing too many computations at a time may not be easily profitable and
  // may increase compile time due to large kernels. Set a limit to it.
  constexpr int64 kMaxFusionBatchSize = 32;
  // CUDA has a parameter size limit of ~4k bytes.
  constexpr int64 kMaxCudaParamSize = 4000;
  size_t accum_io_size = 0;
  auto reach_max_fusion_batch_size = [&](size_t left, size_t right) -> bool {
    if (right - left >= kMaxFusionBatchSize) {
      return true;
    }

    accum_io_size += fusion_instrs_.at(right)->fused_parameters().size() +
                     GetOutputSizeOfFusion(*fusion_instrs_.at(right));

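    // Each parameter/output buffer is passed to the kernel as a pointer,
    // assumed here to take 8 bytes; hence the factor of 8 below.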
    if (accum_io_size * 8 >= kMaxCudaParamSize) {
      return true;
    }

    return false;
  };

  size_t left = pos_;
  size_t right = pos_ + 1;
  size_t first_output_size = GetOutputSizeOfFusion(*fusion_instrs_[left]);
  PrimitiveType first_output_type =
      GetUniqueOutputTypeOfFusion(*fusion_instrs_[left]);
  for (; right < fusion_instrs_.size(); ++right) {
    PrimitiveType cur_output_type =
        GetUniqueOutputTypeOfFusion(*fusion_instrs_[right]);
    if (first_output_type != cur_output_type) {
      // Cannot fuse computations that have multiple output types.
      break;
    } else if (first_output_size !=
               GetOutputSizeOfFusion(*fusion_instrs_[right])) {
      // Cannot fuse computations that have different numbers of outputs.
      break;
    } else if (fusion_instrs_[left]->fused_instruction_count() !=
               fusion_instrs_[right]->fused_instruction_count()) {
      // Do not fuse computations with different instruction counts, as doing
      // so may introduce control divergence. This is a very simple heuristic
      // to avoid fusing computations with too much discrepancy, and we may
      // improve it when the need arises.
      break;
    } else if (reach_max_fusion_batch_size(left, right)) {
      // Hit max fusion batch size.
      break;
    }
  }

  pos_ = right;
  return absl::MakeSpan(fusion_instrs_).subspan(left, right - left);
}

Status HorizontalLoopFusionImpl::CreateFusedComputation(
    absl::Span<HloInstruction*> fused_fusion_instrs,
    std::unique_ptr<HloComputation>* uniq_computation,
    std::vector<HloInstruction*>* bound_operands) {
  // First, build a computation with only params.
  HloComputation::Builder b("horizontally_fused_computation");
  size_t fused_comp_param_id = 0;
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    auto old_params = fused_fusion_instrs[i]->fused_parameters();
    for (size_t j = 0; j < old_params.size(); ++j) {
      auto bound_opnd = fused_fusion_instrs[i]->mutable_operand(j);
      // Name the new parameters in the form param_i_j.
      b.AddInstruction(HloInstruction::CreateParameter(
          fused_comp_param_id++, bound_opnd->shape(),
          absl::StrCat("param_", i, "_", j)));
      bound_operands->push_back(bound_opnd);
    }
  }
  // Always create a dummy tuple instruction to serve as the root of the
  // computation, as HloComputation requires a root instruction to exist. The
  // real root instruction will replace it below.
  auto dummy_root = b.AddInstruction(
      HloInstruction::CreateTuple(std::vector<HloInstruction*>{}));
  *uniq_computation = b.Build(dummy_root);
  auto* comp = uniq_computation->get();

  // Prepare clone_map, which maps old operands to new operands.
  absl::flat_hash_map<HloInstruction*, HloInstruction*> clone_map;
  size_t new_param_id = 0;
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    auto old_params = fused_fusion_instrs[i]->fused_parameters();
    for (size_t j = 0; j < old_params.size(); ++j) {
      auto old_param = old_params[j];
      auto new_param = comp->parameter_instruction(new_param_id++);
      clone_map.insert({old_param, new_param});
    }
  }

  // Clone every fused computation.
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    auto def_to_use_order = fused_fusion_instrs[i]
                                ->fused_instructions_computation()
                                ->MakeInstructionPostOrder();
    for (auto old_instr : def_to_use_order) {
      if (old_instr->opcode() == HloOpcode::kParameter) {
        // Parameters have already been created.
        continue;
      }
      std::vector<HloInstruction*> new_opnds;
      for (auto old_opnd : old_instr->operands()) {
        CHECK(clone_map.find(old_opnd) != clone_map.end());
        new_opnds.push_back(clone_map[old_opnd]);
      }
      auto new_instr = comp->AddInstruction(
          old_instr->CloneWithNewOperands(old_instr->shape(), new_opnds));
      clone_map.insert({old_instr, new_instr});
    }
  }

  std::vector<HloInstruction*> concated_outputs;
  // Since we require each fusion to have the same number of outputs, we can
  // simply use the first fusion as the representative for output size.
  size_t fused_instr_output_size =
      GetOutputSizeOfFusion(*fused_fusion_instrs[0]);
  for (size_t i = 0; i < fused_instr_output_size; ++i) {
    std::vector<HloInstruction*> reshapes(fused_fusion_instrs.size());
    for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
      auto old_output = GetOutputsOfFusion(*fused_fusion_instrs[j])[i];
      auto new_output = clone_map[old_output];
      TF_ASSIGN_OR_RETURN(
          reshapes[j],
          MakeReshapeHlo(ShapeUtil::MakeShapeWithLayout(
                             new_output->shape().element_type(),
                             {ShapeUtil::ElementsIn(new_output->shape())},
                             /*minor_to_major=*/std::vector<int64>(1, 0)),
                         new_output));
    }
    TF_ASSIGN_OR_RETURN(auto concated_output, MakeConcatHlo(reshapes, 0));
    concated_outputs.push_back(concated_output);
  }

  // Make slices of outputs.
  std::vector<HloInstruction*> output_slices(concated_outputs.size() *
                                             fused_fusion_instrs.size());
  for (size_t i = 0; i < concated_outputs.size(); ++i) {
    auto concated_output = concated_outputs[i];
    int64 slice_start = 0;
    // Create a slice per fused computation.
    for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
      auto old_output = GetOutputsOfFusion(*fused_fusion_instrs[j])[i];
      auto shape = old_output->shape();
      int64 slice_limit = slice_start + ShapeUtil::ElementsIn(shape);
      TF_ASSIGN_OR_RETURN(
          output_slices[concated_outputs.size() * j + i],
          MakeSliceHlo(concated_output, {slice_start}, {slice_limit},
                       /*strides=*/{1}));
      slice_start = slice_limit;
    }
  }

  // Make a tuple of output_slices.
  auto tuple = comp->AddInstruction(HloInstruction::CreateTuple(output_slices));
  comp->set_root_instruction(tuple, /*accept_different_shape=*/true);
  TF_RETURN_IF_ERROR(comp->RemoveInstruction(dummy_root));

  return Status::OK();
}

Status HorizontalLoopFusionImpl::Fuse(
    absl::Span<HloInstruction*> fused_fusion_instrs) {
  // Fuse fused_fusion_instrs and replace them with the new fused computation.
  std::unique_ptr<HloComputation> uniq_computation;
  std::vector<HloInstruction*> bound_operands;
  TF_RETURN_IF_ERROR(CreateFusedComputation(
      fused_fusion_instrs, &uniq_computation, &bound_operands));
  auto fused_comp = computation_->parent()->AddEmbeddedComputation(
      std::move(uniq_computation));
  auto hori_fusion_instr =
      computation_->AddInstruction(HloInstruction::CreateFusion(
          fused_comp->root_instruction()->shape(),
          HloInstruction::FusionKind::kInput, bound_operands, fused_comp));
  fused_comp->SetFusionInstruction(hori_fusion_instr);

  // Insert bitcasts and replace corresponding users. Note that we do not
  // insert the bitcasts in the fused computation, as that does not fit the
  // slice input fusion pattern; inserting bitcasts outside the fused
  // computation incurs no performance cost.
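  // For example (illustrative shapes), a replaced fusion that produced an
  // f32[4,4] output now feeds its users through a bitcast of the matching
  // f32[16] slice of the horizontally fused result back to f32[4,4].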
  size_t total_output_id = 0;
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    std::vector<HloInstruction*> bitcasts;
    auto fused_instr = fused_fusion_instrs[i];
    auto num_outputs = GetOutputSizeOfFusion(*fused_instr);
    for (size_t j = 0; j < num_outputs; ++j) {
      auto output = GetOutputsOfFusion(*fused_instr)[j];
      TF_ASSIGN_OR_RETURN(auto gep, MakeGetTupleElementHlo(hori_fusion_instr,
                                                           total_output_id++));
      bitcasts.push_back(computation_->AddInstruction(
          HloInstruction::CreateBitcast(output->shape(), gep)));
    }
    auto bitcast_or_tuple = (bitcasts.size() == 1)
                                ? bitcasts.at(0)
                                : computation_->AddInstruction(
                                      HloInstruction::CreateTuple(bitcasts));
    TF_RETURN_IF_ERROR(
        computation_->ReplaceInstruction(fused_instr, bitcast_or_tuple));
  }

  return Status::OK();
}

StatusOr<bool> HorizontalLoopFusionImpl::Run() {
  bool changed = false;
  XLA_VLOG_LINES(3, computation_->ToString());

  // Using def-to-use order is sound since we do not modify users.
  std::vector<HloInstruction*> def_to_use_order =
      computation_->MakeInstructionPostOrder();
  for (size_t i = 0; i < def_to_use_order.size(); ++i) {
    auto consumer = def_to_use_order[i];
    HorizontalLoopFusionImpl::FusionCandidates fusion_candidates(consumer);
    while (true) {
      auto fusions = fusion_candidates.GetNextSpanOfFusions();
      if (fusions.empty()) {
        break;
      } else if (fusions.size() == 1) {
        // Skip; there is just one fused_instr.
        continue;
      }

      changed = true;
      TF_RETURN_IF_ERROR(Fuse(fusions));
    }
  }

  return changed;
}

}  // namespace

StatusOr<bool> GpuHorizontalLoopFusion::RunOnComputation(
    HloComputation* computation) {
  HorizontalLoopFusionImpl horizontal_fusion_impl(computation);
  return horizontal_fusion_impl.Run();
}

StatusOr<bool> GpuHorizontalLoopFusion::Run(HloModule* module) {
  bool changed = false;
  VLOG(2) << "Run horizontal fusion.";
  for (auto* comp : module->MakeNonfusionComputations()) {
    // Accumulate the per-computation result instead of overwriting `changed`,
    // so a change in an earlier computation is not lost.
    TF_ASSIGN_OR_RETURN(bool comp_changed, RunOnComputation(comp));
    changed |= comp_changed;
  }

  return changed;
}

}  // namespace gpu
}  // namespace xla