/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h"

#include <algorithm>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/core/platform/errors.h"

namespace xla {
namespace gpu {

namespace {

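// Illustrative sketch of what this pass does (names below are hypothetical,
// not from the original source): given two sibling kInput reduce fusions that
// feed the same consumer,
//
//   fusion.0 = f32[128]{0} fusion(p0), kind=kInput, calls=reduce_comp_0
//   fusion.1 = f32[128]{0} fusion(p1), kind=kInput, calls=reduce_comp_1
//   consumer = (f32[128]{0}, f32[128]{0}) tuple(fusion.0, fusion.1)
//
// the pass merges fusion.0 and fusion.1 into a single multi-output input
// fusion, saving a kernel launch.
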
// Gets the representative input shape of the multi-output fusion.
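// For example (illustrative): if the real hero of `instr` is a reduce from
// f32[1024,128] down to f32[1024], this returns the operand shape
// f32[1024,128].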
Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) {
  // Get the HLO that determines the emitter used for lowering.
  const HloInstruction* real_hero = GetRealHeroForMultiOutputFusion(instr);
  if (real_hero->operands().empty()) {
    // Simply return an empty shape if the representative node has no input
    // operands.
    return Shape();
  } else {
    return real_hero->operand(0)->shape();
  }
}

class HorizontalInputFusionImpl {
 public:
  explicit HorizontalInputFusionImpl(HloComputation* computation)
      : computation_(computation) {}

  ~HorizontalInputFusionImpl() {}

  StatusOr<bool> Run();

 private:
  HloComputation* computation_;
};  // HorizontalInputFusionImpl

// Compares the dimensions of `shape_a` and `shape_b`, one by one, from left
// to right.
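// For example (illustrative): [2,3,4] sorts before [2,4,3] because the first
// differing dimension (3 vs. 4) is smaller; a lower-rank shape always sorts
// before a higher-rank one.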
bool CompareShapeDimsFromLeftToRight(const Shape& shape_a,
                                     const Shape& shape_b) {
  if (shape_a.rank() != shape_b.rank()) {
    return shape_a.rank() < shape_b.rank();
  }
  auto dims_a = shape_a.dimensions();
  auto dims_b = shape_b.dimensions();
  for (size_t i = 0; i < dims_a.size(); ++i) {
    if (dims_a[i] != dims_b[i]) {
      return dims_a[i] < dims_b[i];
    }
  }
  // All dimensions are equal; return false so that this comparator remains a
  // valid strict weak ordering for std::sort.
  return false;
}
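// Returns the input-fusion candidates among the producers of `consumer`,
// deduplicated and sorted so that candidates with identical input shapes
// (and, within a shape, similar instruction counts) end up adjacent. The
// caller merges adjacent candidates, so this ordering maximizes merge
// opportunities.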
std::vector<HloInstruction*> FindAndSortFusionCandidates(
    HloInstruction* consumer) {
  absl::flat_hash_set<HloInstruction*> fusion_instr_set;
  std::vector<HloInstruction*> fusion_instrs;
  for (HloInstruction* opnd : consumer->operands()) {
    HloInstruction* predecessor = opnd->LatestNonGteAncestor();
    // Find the input fusion instructions whose only consumer is `consumer`.
    // This guarantees that fusing these candidates never creates a cycle, as
    // there is no back edge.
    if (IsInputFusibleReduction(*predecessor) &&
        IsConsumerTheOnlyNonRootUser(*predecessor, *consumer)) {
      if (fusion_instr_set.insert(predecessor).second) {
        fusion_instrs.push_back(predecessor);
      }
    }
  }

  std::sort(fusion_instrs.begin(), fusion_instrs.end(),
            [&](const HloInstruction* a, const HloInstruction* b) {
              Shape shape_a = GetInputShapeForMultiOutputFusion(*a);
              Shape shape_b = GetInputShapeForMultiOutputFusion(*b);
              if (!ShapeUtil::EqualIgnoringElementType(shape_a, shape_b)) {
                // Sort shapes according to dimensions, so that the same input
                // shapes will be placed adjacent to each other.
                return CompareShapeDimsFromLeftToRight(shape_a, shape_b);
              }
              // Sort `fusion_instrs` according to instruction counts, because
              // we'd like to fuse together computations of similar sizes.
              return GetInstrCountOfFusible(*a) < GetInstrCountOfFusible(*b);
            });

  return fusion_instrs;
}

StatusOr<bool> HorizontalInputFusionImpl::Run() {
  bool changed = false;
  XLA_VLOG_LINES(3, computation_->ToString());

  // Using def-to-use order is sound since we do not modify users.
  std::vector<HloInstruction*> def_to_use_order =
      computation_->MakeInstructionPostOrder();
  for (HloInstruction* consumer : def_to_use_order) {
    auto candidates = FindAndSortFusionCandidates(consumer);
    if (candidates.size() <= 1) {
      continue;
    }

    // Convert candidates into fusions if needed.
    for (size_t j = 0; j < candidates.size(); ++j) {
      if (candidates[j]->opcode() != HloOpcode::kFusion) {
        TF_ASSIGN_OR_RETURN(
            HloInstruction * fusion_instr,
            MakeFusionInstruction(candidates[j],
                                  HloInstruction::FusionKind::kInput));
        candidates[j] = fusion_instr;
        changed = true;
      }
    }

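    // Greedily merge the sorted candidates: keep folding `fused` into the
    // current anchor while the pair is shape-compatible for multi-output
    // fusion and the merged fusion fits in the GPU budget; otherwise `fused`
    // becomes the new anchor.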
    size_t fusion_anchor_id = 0;
    for (size_t j = 1; j < candidates.size(); ++j) {
      HloInstruction* fusion_anchor = candidates[fusion_anchor_id];
      HloInstruction* fused = candidates[j];
      if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) &&
          FusionFitsInBudget(*fusion_anchor, *fused)) {
        VLOG(3) << "Fuse " << fused->ToString() << " into "
                << fusion_anchor->ToString();
        fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused);
        changed = true;
      } else {
        // Update `fusion_anchor_id` since `fused` is either not compatible
        // with, or not beneficial to fuse with, the current fusion anchor.
        VLOG(3) << j - fusion_anchor_id - 1 << " instructions are fused.";
        fusion_anchor_id = j;
      }
    }
  }

  return changed;
}

}  // namespace

StatusOr<bool> GpuHorizontalInputFusion::RunOnComputation(
    HloComputation* computation) {
  HorizontalInputFusionImpl horizontal_fusion_impl(computation);
  return horizontal_fusion_impl.Run();
}

StatusOr<bool> GpuHorizontalInputFusion::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool changed = false;
  VLOG(2) << "Run horizontal input fusion.";
  for (HloComputation* comp :
       module->MakeNonfusionComputations(execution_threads)) {
    // Accumulate the per-computation result; overwriting `changed` would
    // report only the status of the last computation.
    TF_ASSIGN_OR_RETURN(bool comp_changed, RunOnComputation(comp));
    changed |= comp_changed;
  }

  return changed;
}

}  // namespace gpu
}  // namespace xla