• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/multi_output_fusion.h"
17 
18 #include <stdint.h>
19 
20 #include <memory>
21 #include <vector>
22 
23 #include "absl/algorithm/container.h"
24 #include "absl/container/flat_hash_set.h"
25 #include "tensorflow/compiler/xla/debug_options_flags.h"
26 #include "tensorflow/compiler/xla/layout_util.h"
27 #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
28 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
29 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
30 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
31 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
32 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
33 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
34 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
35 #include "tensorflow/compiler/xla/shape_util.h"
36 #include "tensorflow/core/platform/types.h"
37 
38 namespace xla {
39 namespace gpu {
40 
41 namespace {
42 
IsProfitableOperand(HloInstruction * instr)43 bool IsProfitableOperand(HloInstruction* instr) {
44   // kConstant instruction will not have memory reads, so it won't be a profit
45   // source. Skip them.
46   if (instr->opcode() == HloOpcode::kConstant &&
47       ShapeUtil::IsEffectiveScalar(instr->shape())) {
48     return false;
49   }
50   return true;
51 }
52 
LegalToFuse(HloInstruction * instr1,HloInstruction * instr2)53 bool LegalToFuse(HloInstruction* instr1, HloInstruction* instr2) {
54   // If we're fusing fusions only do it if the fusion kind matches. Loop fusions
55   // merge into bigger loop fusions and input (reduce) fusions become fusions
56   // with multiple reduce outputs. We could fuse reduce and loop fusions
57   // together too (the result being an input fusion) if we find cases where this
58   // improves things. Also disable fusing standalone input-fusible reduces into
59   // loop fusions.
60   CHECK(instr1->opcode() == HloOpcode::kFusion);
61   if ((instr2->opcode() == HloOpcode::kFusion &&
62        instr1->fusion_kind() != instr2->fusion_kind()) ||
63       (IsReductionFromOrToContiguousDimensions(*instr2) &&
64        instr1->IsLoopFusion())) {
65     return false;
66   }
67   // The emitter only supports in-place DUS for fusions with a single DUS at the
68   // root. Don't sibling fuse DUS for now.
69   // TODO(b/119178699): Multi-output fusing DUS can improve performance if we
70   // share the input and output buffers and add support to the emitter.
71   if (instr1->fused_expression_root()->opcode() ==
72           HloOpcode::kDynamicUpdateSlice ||
73       (instr2->opcode() == HloOpcode::kFusion &&
74        instr2->fused_expression_root()->opcode() ==
75            HloOpcode::kDynamicUpdateSlice)) {
76     return false;
77   }
78   // Do this check last, as it may be expensive.
79   return !FusionWouldBeTooLarge(*instr1, *instr2);
80 }
81 
82 // We prefer multi-output fusions over other fusions over unfused ops, because
83 // we want to preserve fusion opportunities if possible.
FusionPriority(const HloInstruction * instr)84 int FusionPriority(const HloInstruction* instr) {
85   if (instr->IsMultiOutputFusion()) {
86     return 2;
87   }
88   if (instr->opcode() == HloOpcode::kFusion) {
89     return 1;
90   }
91   return 0;
92 }
93 
SelectPreferredFusionCandidate(const std::vector<HloInstruction * > candidates)94 HloInstruction* SelectPreferredFusionCandidate(
95     const std::vector<HloInstruction*> candidates) {
96   if (candidates.empty()) {
97     return nullptr;
98   }
99   return *std::max_element(
100       candidates.begin(), candidates.end(),
101       [](const HloInstruction* a, const HloInstruction* b) {
102         return FusionPriority(a) < FusionPriority(b);
103       });
104 }
105 
GetProducerConsumerMultiOutputFusionCandidates(const HloInstruction * producer,const HloReachabilityMap & reachability)106 std::vector<HloInstruction*> GetProducerConsumerMultiOutputFusionCandidates(
107     const HloInstruction* producer, const HloReachabilityMap& reachability) {
108   std::vector<HloInstruction*> fusion_candidates;
109   // If there is only one user, and it is not a multi-output fusion node, this
110   // fusion possibility was already considered and rejected by the FusionMerger
111   // pass. No need to try again!
112   if (producer->user_count() == 1 &&
113       !producer->users()[0]->IsMultiOutputFusion()) {
114     return fusion_candidates;
115   }
116   for (HloInstruction* consumer : producer->users()) {
117     VLOG(3) << "Looking at producer " << producer->name()
118             << " and its consumer " << consumer->name();
119     if (!IsFusibleAsMultiOutputFusionRoot(*consumer)) {
120       VLOG(3) << "Consumer " << consumer->name()
121               << " is not eligible as multi-output fusion root.";
122       continue;
123     }
124     if (!IsProducerConsumerMultiOutputFusible(*producer, *consumer)) {
125       VLOG(3) << producer->name() << " and " << consumer->name()
126               << " are not fusible.";
127       continue;
128     }
129     // Do not fuse a producer if the other operands of the fusion are
130     // reachable from the producer, this would create a cycle.
131     auto operand_reachable_from_producer = [&](const HloInstruction* operand) {
132       // If a get-tuple-element instruction is not in the reachability
133       // map, it has been created by fusion in this pass. Simply move
134       // on to its operand, which is in the reachability map.
135       if (!reachability.IsPresent(operand) &&
136           operand->opcode() == HloOpcode::kGetTupleElement) {
137         operand = operand->operand(0);
138       }
139       CHECK(reachability.IsPresent(operand) && reachability.IsPresent(producer))
140           << "Reachability map is incomplete. This should never "
141              "happen.";
142       return producer != operand && reachability.IsReachable(producer, operand);
143     };
144     if (absl::c_any_of(consumer->operands(), operand_reachable_from_producer)) {
145       VLOG(3) << producer->name() << " would introduce a cycle when fused.";
146       continue;
147     }
148     if (FusionWouldBeTooLarge(*producer, *consumer)) {
149       VLOG(3) << producer->name() << " and " << consumer->name()
150               << " would be too large of a fusion.";
151       continue;
152     }
153     // Make sure the emitter can codegen the fusion op efficiently. We currently
154     // can have exponential time/memory requirements for emitting certain fusion
155     // ops, in which case we don't want to fuse.
156     // TODO(b/119692968): Remove this once fixed in the emitter.
157     if (FusedIrEmitter::IsFusedIrEmitterInefficient(consumer, producer)) {
158       VLOG(3) << "Fusion of " << producer->name() << " into "
159               << consumer->name()
160               << " would result in overly large code duplication.";
161       continue;
162     }
163     fusion_candidates.push_back(consumer);
164   }
165   return fusion_candidates;
166 }
167 
IsSiblingFusionCandidate(const HloInstruction * instr)168 bool IsSiblingFusionCandidate(const HloInstruction* instr) {
169   if (instr->user_count() == 0) {
170     return false;
171   }
172   if (!IsFusibleAsMultiOutputFusionRoot(*instr)) {
173     return false;
174   }
175   // Check if the users of multioutput fusion is not a get-tuple-element.
176   // If this is the case, we bail out because the transformation assumes
177   // the users are get-tuple-element.
178   if (instr->IsMultiOutputFusion()) {
179     for (auto user : instr->users()) {
180       if (user->opcode() != HloOpcode::kGetTupleElement) {
181         return false;
182       }
183     }
184   }
185   return true;
186 }
187 
188 }  // namespace
189 
RecomputeReachability()190 void GpuMultiOutputFusion::RecomputeReachability() {
191   reachability_ = HloReachabilityMap::Build(computation_);
192 }
193 
FuseSiblings(HloInstruction * parent)194 bool GpuMultiOutputFusion::FuseSiblings(HloInstruction* parent) {
195   if (!IsProfitableOperand(parent)) {
196     return false;
197   }
198   bool changed = false;
199   std::vector<HloInstruction*> siblings = parent->users();
200   // Sort the siblings such that multi-output fusion ops occur first, followed
201   // by fusion ops, followed by unfused ops.
202   absl::c_stable_sort(siblings,
203                       [](const HloInstruction* a, const HloInstruction* b) {
204                         return FusionPriority(a) > FusionPriority(b);
205                       });
206   for (auto i = siblings.begin(); i != siblings.end();) {
207     VLOG(3) << "Considering " << (*i)->name();
208     if ((*i)->opcode() != HloOpcode::kFusion || !IsSiblingFusionCandidate(*i)) {
209       ++i;
210       continue;
211     }
212     for (auto j = i + 1; j != siblings.end();) {
213       VLOG(3) << "Considering " << (*i)->name() << " and " << (*j)->name();
214       if (!IsSiblingFusionCandidate(*j) || reachability_->IsConnected(*i, *j) ||
215           !ShapesCompatibleForMultiOutputFusion(*(*i), *(*j)) ||
216           !LegalToFuse(*i, *j)) {
217         ++j;
218         continue;
219       }
220       if (!ConsumeFuel(name(), [&] {
221             return absl::StrFormat("Not fusing siblings %s and %s.",
222                                    (*i)->name(), (*j)->name());
223           })) {
224         ++j;
225         continue;
226       }
227       VLOG(2) << "Fuse siblings " << (*i)->name() << " and " << (*j)->name();
228       HloInstruction* remaining = *i;
229       HloInstruction* fused = *j;
230       if (fused->opcode() == HloOpcode::kFusion) {
231         remaining->MergeFusionInstructionIntoMultiOutput(fused);
232       } else {
233         remaining->FuseInstructionIntoMultiOutput(fused);
234         CHECK_EQ(0, fused->user_count());
235         TF_CHECK_OK(computation_->RemoveInstruction(fused));
236       }
237       changed = true;
238       siblings.erase(j);
239       RecomputeReachability();
240     }
241     ++i;
242   }
243   return changed;
244 }
245 
DoMultiOutputFusion()246 StatusOr<bool> GpuMultiOutputFusion::DoMultiOutputFusion() {
247   bool changed = false;
248   RecomputeReachability();
249   std::vector<HloInstruction*> defs_before_uses =
250       computation_->MakeInstructionPostOrder();
251 
252   auto dump_fusion_state = [&] {
253     if (computation_->parent()
254             ->config()
255             .debug_options()
256             .xla_dump_fusion_visualization()) {
257       TF_RETURN_IF_ERROR(
258           RegisterFusionState(*computation_, "GpuMultiOutputFusion"));
259     }
260     return Status::OK();
261   };
262 
263   while (!defs_before_uses.empty()) {
264     // Traverse the HLO in uses-before-defs order by removing instruction from
265     // the back of the vector.
266     HloInstruction* producer = defs_before_uses.back();
267     defs_before_uses.pop_back();
268     // Never multi-output fuse constants.  To the extent that we want to fuse
269     // constants, that should be handled by the regular fusion pass.
270     if (producer->opcode() == HloOpcode::kConstant) {
271       VLOG(3) << producer->name() << " is a constant.";
272       continue;
273     }
274     // First, fuse the consumer ops of the current op, which are siblings.
275     if (FuseSiblings(/*parent=*/producer)) {
276       changed = true;
277     }
278     // Second, perform producer-consumer multi-output fusion. This order will
279     // ensure that all get-tuple-element ops inserted as a by-product of
280     // multi-output fusion will occur before the current op in the order of
281     // traversal, and hence, not get into the way of subsequent fusion attempts.
282     const auto candidates = GetProducerConsumerMultiOutputFusionCandidates(
283         producer, *reachability_);
284     auto* consumer_for_fusion = SelectPreferredFusionCandidate(candidates);
285     if (consumer_for_fusion == nullptr) {
286       continue;
287     }
288     if (!ConsumeFuel(name(), [&] {
289           return absl::StrFormat("Not fusing %s and %s.", producer->name(),
290                                  consumer_for_fusion->name());
291         })) {
292       continue;
293     }
294     changed = true;
295     if (consumer_for_fusion->opcode() == HloOpcode::kFusion) {
296       VLOG(2) << "Fuse producer " << producer->name() << " into its consumer "
297               << consumer_for_fusion->name();
298       if (producer->opcode() == HloOpcode::kFusion) {
299         consumer_for_fusion->MergeFusionInstructionIntoMultiOutput(producer);
300       } else {
301         consumer_for_fusion->FuseInstructionIntoMultiOutput(producer);
302         CHECK_EQ(0, producer->user_count());
303         TF_CHECK_OK(computation_->RemoveInstruction(producer));
304       }
305 
306       TF_RETURN_IF_ERROR(dump_fusion_state());
307       RecomputeReachability();
308       continue;
309     }
310     HloInstruction* input_fusion =
311         computation_->AddInstruction(HloInstruction::CreateFusion(
312             consumer_for_fusion->shape(),
313             ChooseFusionKind(*producer, *consumer_for_fusion),
314             consumer_for_fusion));
315     VLOG(2) << "Fuse producer " << producer->name() << " and its consumer "
316             << consumer_for_fusion->name() << " into " << input_fusion->name();
317     TF_CHECK_OK(
318         computation_->ReplaceInstruction(consumer_for_fusion, input_fusion));
319     if (producer->opcode() == HloOpcode::kFusion) {
320       input_fusion->MergeFusionInstructionIntoMultiOutput(producer);
321     } else {
322       input_fusion->FuseInstructionIntoMultiOutput(producer);
323       CHECK_EQ(0, producer->user_count());
324       TF_CHECK_OK(computation_->RemoveInstruction(producer));
325     }
326 
327     TF_RETURN_IF_ERROR(dump_fusion_state());
328     RecomputeReachability();
329   }
330   return changed;
331 }
332 
Run(HloModule * module)333 StatusOr<bool> GpuMultiOutputFusion::Run(HloModule* module) {
334   bool changed = false;
335   for (auto* computation : module->MakeNonfusionComputations()) {
336     computation_ = computation;
337     TF_ASSIGN_OR_RETURN(bool fusion_changed, DoMultiOutputFusion());
338     if (fusion_changed) {
339       changed = true;
340     }
341   }
342   return changed;
343 }
344 
345 }  // namespace gpu
346 }  // namespace xla
347