/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"

#include <stdint.h>

#include <algorithm>
#include <memory>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace gpu {

namespace {

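// Returns true if `instr` is worth considering as the shared operand of a
// sibling multi-output fusion, i.e. if fusing its users could save reads of
// `instr` from memory.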
bool IsProfitableOperand(HloInstruction* instr) {
  // An effective-scalar kConstant instruction involves no memory reads, so it
  // cannot be a source of profit. Skip it.
  if (instr->opcode() == HloOpcode::kConstant &&
      ShapeUtil::IsEffectiveScalar(instr->shape())) {
    return false;
  }
  return true;
}

bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2) {
  // If we're fusing fusions, only do it if the fusion kinds match. Loop
  // fusions merge into bigger loop fusions, and input (reduce) fusions become
  // fusions with multiple reduce outputs. We could fuse reduce and loop
  // fusions together too (the result being an input fusion) if we find cases
  // where this improves things. Also disable fusing standalone input-fusible
  // reduces into loop fusions.
  CHECK(instr1->opcode() == HloOpcode::kFusion);
  if ((instr2->opcode() == HloOpcode::kFusion &&
       instr1->fusion_kind() != instr2->fusion_kind()) ||
      (IsReductionFromOrToContiguousDimensions(*instr2) &&
       instr1->IsLoopFusion())) {
    return false;
  }
  // The emitter only supports in-place DUS for fusions with a single DUS at
  // the root. Don't sibling-fuse DUS for now.
  // TODO(b/119178699): Multi-output fusing DUS can improve performance if we
  // share the input and output buffers and add support to the emitter.
  if (instr1->fused_expression_root()->opcode() ==
          HloOpcode::kDynamicUpdateSlice ||
      (instr2->opcode() == HloOpcode::kFusion &&
       instr2->fused_expression_root()->opcode() ==
           HloOpcode::kDynamicUpdateSlice)) {
    return false;
  }
  // Do this check last, as it may be expensive.
  return !FusionWouldBeTooLarge(*instr1, *instr2);
}

// We prefer multi-output fusions over other fusions, and fusions over unfused
// ops, because we want to preserve fusion opportunities if possible.
int FusionPriority(const HloInstruction* instr) {
  if (instr->IsMultiOutputFusion()) {
    return 2;
  }
  if (instr->opcode() == HloOpcode::kFusion) {
    return 1;
  }
  return 0;
}

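// Returns the element of `candidates` with the highest FusionPriority, or
// nullptr if `candidates` is empty.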
HloInstruction* SelectPreferredFusionCandidate(
    const std::vector<HloInstruction*>& candidates) {
  if (candidates.empty()) {
    return nullptr;
  }
  return *std::max_element(
      candidates.begin(), candidates.end(),
      [](const HloInstruction* a, const HloInstruction* b) {
        return FusionPriority(a) < FusionPriority(b);
      });
}

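// Returns the users of `producer` into which `producer` could be fused as an
// additional multi-output fusion root: each candidate must itself be fusible
// as a multi-output fusion root, be fusible with `producer`, not introduce a
// cycle, not make the fusion too large, and not trigger pathological code
// duplication in the emitter.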
std::vector<HloInstruction*> GetProducerConsumerMultiOutputFusionCandidates(
    const HloInstruction* producer, const HloReachabilityMap& reachability) {
  std::vector<HloInstruction*> fusion_candidates;
  // If there is only one user, and it is not a multi-output fusion node, this
  // fusion possibility was already considered and rejected by the FusionMerger
  // pass. No need to try again!
  if (producer->user_count() == 1 &&
      !producer->users()[0]->IsMultiOutputFusion()) {
    return fusion_candidates;
  }
  for (HloInstruction* consumer : producer->users()) {
    VLOG(3) << "Looking at producer " << producer->name()
            << " and its consumer " << consumer->name();
    if (!IsFusibleAsMultiOutputFusionRoot(*consumer)) {
      VLOG(3) << "Consumer " << consumer->name()
              << " is not eligible as multi-output fusion root.";
      continue;
    }
    if (!IsProducerConsumerMultiOutputFusible(*producer, *consumer)) {
      VLOG(3) << producer->name() << " and " << consumer->name()
              << " are not fusible.";
      continue;
    }
    // Do not fuse a producer if the other operands of the fusion are
    // reachable from the producer; this would create a cycle.
    auto operand_reachable_from_producer = [&](const HloInstruction* operand) {
      // If a get-tuple-element instruction is not in the reachability
      // map, it has been created by fusion in this pass. Simply move
      // on to its operand, which is in the reachability map.
      if (!reachability.IsPresent(operand) &&
          operand->opcode() == HloOpcode::kGetTupleElement) {
        operand = operand->operand(0);
      }
      CHECK(reachability.IsPresent(operand) && reachability.IsPresent(producer))
          << "Reachability map is incomplete. This should never happen.";
      return producer != operand && reachability.IsReachable(producer, operand);
    };
    if (absl::c_any_of(consumer->operands(), operand_reachable_from_producer)) {
      VLOG(3) << producer->name() << " would introduce a cycle when fused.";
      continue;
    }
    if (FusionWouldBeTooLarge(*producer, *consumer)) {
      VLOG(3) << producer->name() << " and " << consumer->name()
              << " would be too large of a fusion.";
      continue;
    }
    // Make sure the emitter can codegen the fusion op efficiently. We
    // currently can have exponential time/memory requirements for emitting
    // certain fusion ops, in which case we don't want to fuse.
    // TODO(b/119692968): Remove this once fixed in the emitter.
    if (FusedIrEmitter::IsFusedIrEmitterInefficient(consumer, producer)) {
      VLOG(3) << "Fusion of " << producer->name() << " into "
              << consumer->name()
              << " would result in overly large code duplication.";
      continue;
    }
    fusion_candidates.push_back(consumer);
  }
  return fusion_candidates;
}

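// Returns true if `instr` is eligible to participate as a sibling in a
// multi-output fusion: it has at least one user, it can be a multi-output
// fusion root, and, if it already is a multi-output fusion, all of its users
// are get-tuple-element instructions.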
bool IsSiblingFusionCandidate(const HloInstruction* instr) {
  if (instr->user_count() == 0) {
    return false;
  }
  if (!IsFusibleAsMultiOutputFusionRoot(*instr)) {
    return false;
  }
  // Check that the users of a multi-output fusion are all get-tuple-element
  // instructions. If they are not, we bail out because the transformation
  // assumes the users are get-tuple-element.
  if (instr->IsMultiOutputFusion()) {
    for (auto user : instr->users()) {
      if (user->opcode() != HloOpcode::kGetTupleElement) {
        return false;
      }
    }
  }
  return true;
}

}  // namespace

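// Rebuilds the reachability map of the current computation. The map has to be
// recomputed after every fusion, because fusing instructions changes the
// dependency structure of the computation.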
void GpuMultiOutputFusion::RecomputeReachability() {
  reachability_ = HloReachabilityMap::Build(computation_);
}

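// Greedily fuses pairs of users ("siblings") of `parent` into multi-output
// fusions. Siblings are visited in order of decreasing FusionPriority, and a
// pair is fused only if doing so is legal, shape-compatible, and does not
// create a cycle. Returns true if any fusion was performed.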
bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent) {
  if (!IsProfitableOperand(parent)) {
    return false;
  }
  bool changed = false;
  std::vector<HloInstruction*> siblings = parent->users();
  // Sort the siblings such that multi-output fusion ops occur first, followed
  // by fusion ops, followed by unfused ops.
  absl::c_stable_sort(siblings,
                      [](const HloInstruction* a, const HloInstruction* b) {
                        return FusionPriority(a) > FusionPriority(b);
                      });
  for (auto i = siblings.begin(); i != siblings.end();) {
    VLOG(3) << "Considering " << (*i)->name();
    if ((*i)->opcode() != HloOpcode::kFusion || !IsSiblingFusionCandidate(*i)) {
      ++i;
      continue;
    }
    for (auto j = i + 1; j != siblings.end();) {
      VLOG(3) << "Considering " << (*i)->name() << " and " << (*j)->name();
      if (!IsSiblingFusionCandidate(*j) || reachability_->IsConnected(*i, *j) ||
          !ShapesCompatibleForMultiOutputFusion(*(*i), *(*j)) ||
          !LegalToFuse(*i, *j)) {
        ++j;
        continue;
      }
      if (!ConsumeFuel(name(), [&] {
            return absl::StrFormat("Not fusing siblings %s and %s.",
                                   (*i)->name(), (*j)->name());
          })) {
        ++j;
        continue;
      }
      VLOG(2) << "Fuse siblings " << (*i)->name() << " and " << (*j)->name();
      HloInstruction* remaining = *i;
      HloInstruction* fused = *j;
      if (fused->opcode() == HloOpcode::kFusion) {
        remaining->MergeFusionInstructionIntoMultiOutput(fused);
      } else {
        remaining->FuseInstructionIntoMultiOutput(fused);
        CHECK_EQ(0, fused->user_count());
        TF_CHECK_OK(computation_->RemoveInstruction(fused));
      }
      changed = true;
      // erase() returns the iterator to the element after the erased one, so
      // the inner loop continues with the next remaining sibling.
      j = siblings.erase(j);
      RecomputeReachability();
    }
    ++i;
  }
  return changed;
}

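// Runs one pass of multi-output fusion over the current computation: for every
// instruction, first fuses its consumers with each other (sibling fusion) and
// then tries to fuse the instruction itself into one of its consumers
// (producer-consumer fusion). Returns true if the computation was changed.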
StatusOr<bool> GpuMultiOutputFusion::DoMultiOutputFusion() {
  bool changed = false;
  RecomputeReachability();
  std::vector<HloInstruction*> defs_before_uses =
      computation_->MakeInstructionPostOrder();

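  // Dumps the current fusion state for visualization when the
  // xla_dump_fusion_visualization debug option is enabled.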
  auto dump_fusion_state = [&] {
    if (computation_->parent()
            ->config()
            .debug_options()
            .xla_dump_fusion_visualization()) {
      TF_RETURN_IF_ERROR(
          RegisterFusionState(*computation_, "GpuMultiOutputFusion"));
    }
    return Status::OK();
  };

  while (!defs_before_uses.empty()) {
    // Traverse the HLO in uses-before-defs order by removing instructions from
    // the back of the vector.
    HloInstruction* producer = defs_before_uses.back();
    defs_before_uses.pop_back();
    // Never multi-output fuse constants. To the extent that we want to fuse
    // constants, that should be handled by the regular fusion pass.
    if (producer->opcode() == HloOpcode::kConstant) {
      VLOG(3) << producer->name() << " is a constant.";
      continue;
    }
    // First, fuse the consumer ops of the current op, which are siblings.
    if (FuseSiblings(/*parent=*/producer)) {
      changed = true;
    }
    // Second, perform producer-consumer multi-output fusion. This order
    // ensures that all get-tuple-element ops inserted as a by-product of
    // multi-output fusion will occur before the current op in the order of
    // traversal, and hence will not get in the way of subsequent fusion
    // attempts.
    const auto candidates = GetProducerConsumerMultiOutputFusionCandidates(
        producer, *reachability_);
    auto* consumer_for_fusion = SelectPreferredFusionCandidate(candidates);
    if (consumer_for_fusion == nullptr) {
      continue;
    }
    if (!ConsumeFuel(name(), [&] {
          return absl::StrFormat("Not fusing %s and %s.", producer->name(),
                                 consumer_for_fusion->name());
        })) {
      continue;
    }
    changed = true;
    if (consumer_for_fusion->opcode() == HloOpcode::kFusion) {
      VLOG(2) << "Fuse producer " << producer->name() << " into its consumer "
              << consumer_for_fusion->name();
      if (producer->opcode() == HloOpcode::kFusion) {
        consumer_for_fusion->MergeFusionInstructionIntoMultiOutput(producer);
      } else {
        consumer_for_fusion->FuseInstructionIntoMultiOutput(producer);
        CHECK_EQ(0, producer->user_count());
        TF_CHECK_OK(computation_->RemoveInstruction(producer));
      }

      TF_RETURN_IF_ERROR(dump_fusion_state());
      RecomputeReachability();
      continue;
    }
    // The consumer is not a fusion yet: wrap it in a new fusion instruction
    // and then merge the producer into that fusion as an additional output.
    HloInstruction* input_fusion =
        computation_->AddInstruction(HloInstruction::CreateFusion(
            consumer_for_fusion->shape(),
            ChooseFusionKind(*producer, *consumer_for_fusion),
            consumer_for_fusion));
    VLOG(2) << "Fuse producer " << producer->name() << " and its consumer "
            << consumer_for_fusion->name() << " into " << input_fusion->name();
    TF_CHECK_OK(
        computation_->ReplaceInstruction(consumer_for_fusion, input_fusion));
    if (producer->opcode() == HloOpcode::kFusion) {
      input_fusion->MergeFusionInstructionIntoMultiOutput(producer);
    } else {
      input_fusion->FuseInstructionIntoMultiOutput(producer);
      CHECK_EQ(0, producer->user_count());
      TF_CHECK_OK(computation_->RemoveInstruction(producer));
    }

    TF_RETURN_IF_ERROR(dump_fusion_state());
    RecomputeReachability();
  }
  return changed;
}

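// Runs the multi-output fusion pass on every non-fusion computation of
// `module`. Returns true if any computation was changed.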
StatusOr<bool> GpuMultiOutputFusion::Run(HloModule* module) {
  bool changed = false;
  for (auto* computation : module->MakeNonfusionComputations()) {
    computation_ = computation;
    TF_ASSIGN_OR_RETURN(bool fusion_changed, DoMultiOutputFusion());
    if (fusion_changed) {
      changed = true;
    }
  }
  return changed;
}

}  // namespace gpu
}  // namespace xla