/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"

#include <algorithm>
#include <iterator>
#include <stack>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/instruction_fusion.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {
namespace gpu {
namespace {

// The amount of shared memory a CUDA kernel can use.
//
// Stay on the conservative side: this is smaller than the full 64kB, which
// leaves some extra space for the cache.
constexpr int64_t kSharedMemoryBudgetInBytes = 40000;

bool IfFusedReadsElementsMultipleTimes(const HloInstruction& instr) {
  CHECK_NE(instr.opcode(), HloOpcode::kFusion) << "`instr` has to be unfused.";
  if (instr.opcode() == HloOpcode::kReduce &&
      !IsReductionFromOrToContiguousDimensions(instr)) {
    return true;
  }
  // Avoid fusing reduce-window when the stride is less than the window size,
  // to minimize the number of times the same elements are read.
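  // (Hypothetical illustration, not from the original: a 1-D window of size 3
  // with stride 1 covers each input element with up to three overlapping
  // windows, so a fused producer would be re-evaluated up to three times.)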
  if (instr.opcode() == HloOpcode::kReduceWindow) {
    for (const auto& dim : instr.window().dimensions()) {
      if (dim.size() > dim.stride()) {
        return true;
      }
    }
  }
  return false;
}

}  // namespace

bool IsPhysicallyTransposing(const HloInstruction& instr) {
  if (instr.opcode() == HloOpcode::kFusion) {
    for (const HloInstruction* fused_instr : instr.fused_instructions()) {
      if (IsPhysicallyTransposing(*fused_instr)) {
        return true;
      }
    }
  }

  // A fusion iterates over its output in physically-contiguous order. This
  // applies "upwards" to operands.  Only an operator that changes an operand's
  // physical layout can create a "bad" memory access pattern.
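  // (Illustrative example, not from the original: a transpose of
  // f32[16,32]{1,0} with dimensions={1,0} into f32[32,16]{1,0} physically
  // reshuffles the data, while the same transpose into f32[32,16]{0,1}
  // preserves the memory order and is a bitcast.)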
  return instr.opcode() == HloOpcode::kCopy ||
         (instr.opcode() == HloOpcode::kTranspose &&
          !ShapeUtil::TransposeIsBitcast(instr.operand(0)->shape(),
                                         instr.shape(), instr.dimensions()));
}

bool IsReduceInputFusion(const HloInstruction& instr) {
  if (instr.IsMultiOutputFusion()) {
    for (const HloInstruction* operand :
         instr.fused_expression_root()->operands()) {
      if (IsReductionFromOrToContiguousDimensions(*operand)) {
        CHECK(instr.IsInputFusion())
            << " Multi-output fusion rooted at reduction-to-vector ops must be "
               "of kind kInput: "
            << instr.ToString();
        return true;
      }
    }
  } else if (instr.opcode() == HloOpcode::kFusion &&
             IsReductionFromOrToContiguousDimensions(
                 *instr.fused_expression_root())) {
    CHECK(instr.IsInputFusion())
        << " Fusion rooted at reduction-to-vector op must be of kind kInput: "
        << instr.ToString();
    return true;
  }
  return false;
}

bool IsInputFusibleReduction(const HloInstruction& instr) {
  return IsReduceInputFusion(instr) ||
         IsReductionFromOrToContiguousDimensions(instr);
}

const HloInstruction* GetRealHeroForMultiOutputFusion(
    const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kFusion) {
    return &instr;
  }
  auto fused_expression_root = instr.fused_expression_root();
  if (!instr.IsMultiOutputFusion()) {
    return fused_expression_root;
  }
  // If possible, we want to pick a reduction-from-or-to-contiguous-dims
  // operand of the fusion root, because it has the most constraints.
  for (const auto* inst : fused_expression_root->operands()) {
    if (IsReductionFromOrToContiguousDimensions(*inst)) {
      return inst;
    }
  }
  return fused_expression_root->operands()[0];
}

bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
                                          const HloInstruction& instr2) {
  // Multi-output fusion kernels share a common parallel loop. The loop
  // dimensions are determined by the instruction shapes.
  auto get_loop_shape = [&](const HloInstruction* element_instr) {
    // Special-case reduction-to-vector ops: The loop dimensions are determined
    // by the shape of the first operand.
    if (IsReductionFromOrToContiguousDimensions(*element_instr)) {
      return element_instr->operand(0)->shape();
    }
    return element_instr->shape();
  };

  // All shapes of the root tuple of multi-output fusions should agree, i.e.
  // all root ops should have equal output shapes. An exception is
  // reduction-to-vector ops: there, the input shape of the reduction (the
  // first operand's shape) and the reduction dimensions need to match.
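  // (Hypothetical HLO sketch of two compatible heroes: both reduce the same
  // input shape over the same dimensions, so they can share one parallel loop:
  //   %r0 = f32[32]{0} reduce(f32[32,1024]{1,0} %p0, %zero), dimensions={1}
  //   %r1 = f32[32]{0} reduce(f32[32,1024]{1,0} %p1, %zero), dimensions={1})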
  auto* instr_1 = GetRealHeroForMultiOutputFusion(instr1);
  auto* instr_2 = GetRealHeroForMultiOutputFusion(instr2);
  if (IsReductionFromOrToContiguousDimensions(*instr_1) &&
      IsReductionFromOrToContiguousDimensions(*instr_2) &&
      !AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) {
    return false;
  }
  // The elementwise output shapes must be the same (including layout).
  return ShapeUtil::EqualIgnoringElementType(get_loop_shape(instr_1),
                                             get_loop_shape(instr_2));
}

bool IsInputFusibleScatter(const HloInstruction& instr) {
  return instr.opcode() == HloOpcode::kScatter ||
         (instr.opcode() == HloOpcode::kFusion &&
          instr.fusion_kind() == HloInstruction::FusionKind::kInput &&
          instr.fused_expression_root()->opcode() == HloOpcode::kScatter);
}

bool IsInputFusible(const HloInstruction& instr) {
  // Input fusion only handles non-elemental reduction and scatter operations.
  return instr.IsFusible() &&
         (IsInputFusibleReduction(instr) || IsInputFusibleScatter(instr));
}

bool IsLoopFusible(const HloInstruction& instr) {
  // Don't fuse get-tuple-element on GPU: We can, but it's slower than not
  // fusing.  We never generate kernels for unfused GTEs.  Instead, if an
  // unfused GTE is an input to a kernel (including a fusion kernel), we
  // compute the address of the GTE at the top of the kernel.  Often we know
  // the address of the GTE result statically, so we can do this without
  // chasing any pointers.
  return instr.IsFusible() &&
         ((instr.IsElementwise() && instr.operand_count() > 0) ||
          instr.opcode() == HloOpcode::kBitcast ||
          instr.opcode() == HloOpcode::kBroadcast ||
          instr.opcode() == HloOpcode::kConcatenate ||
          instr.opcode() == HloOpcode::kDynamicSlice ||
          instr.opcode() == HloOpcode::kDynamicUpdateSlice ||
          (instr.opcode() == HloOpcode::kFusion &&
           instr.fusion_kind() == HloInstruction::FusionKind::kLoop) ||
          instr.opcode() == HloOpcode::kGather ||
          instr.opcode() == HloOpcode::kIota ||
          instr.opcode() == HloOpcode::kPad ||
          (instr.opcode() == HloOpcode::kReduce &&
           !IsReductionFromOrToContiguousDimensions(instr) &&
           !instr.shape().IsTuple()) ||  // TODO(b/129089333): Don't fuse
                                         // variadic reductions.
          instr.opcode() == HloOpcode::kReduceWindow ||
          instr.opcode() == HloOpcode::kReshape ||
          instr.opcode() == HloOpcode::kReverse ||
          instr.opcode() == HloOpcode::kSlice ||
          instr.opcode() == HloOpcode::kConstant ||
          instr.opcode() == HloOpcode::kTranspose);
}

FusionDecision IsProducerConsumerFusible(const HloInstruction& producer,
                                         const HloInstruction& consumer) {
  if (!IsLoopFusible(producer)) {
    return "the producer is not loop-fusible";
  }

  if (!IsInputFusible(consumer) && !IsLoopFusible(consumer)) {
    return "the consumer is not input-fusible and not loop-fusible";
  }

  // Skip multiple output fusion. It's not yet supported.
  if (producer.IsMultiOutputFusion()) {
    return "the producer is not fusible as it is a multi-output fusion";
  }

  if (CreatesNestedLoop(producer, consumer)) {
    return "the fusion would create a nested loop";
  }

  // Do not fuse into fusions if the resulting kernel would suffer from
  // uncoalesced reads due to a transposed memory access pattern.
  if (IsInputFusibleReduction(consumer) && IsPhysicallyTransposing(producer)) {
    return "fusing the producer would break read coalescing";
  }

  // Fuse scalar constants into loop fusion nodes. This reduces the number of
  // parameters and makes matching scalar broadcasts easier.
  //
  // Don't fuse other constants: Unfused constants in GPU land can be
  // represented as an external constant (i.e. not emitted in LLVM IR / PTX),
  // but fused constants are handled by shared CPU/GPU code and always emitted
  // in the IR/PTX.  The external constant representation makes for faster
  // compiles and significantly smaller assembly code.
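  // (Hypothetical example: a %c = f32[] constant(0.5) feeding a broadcast
  // inside a loop fusion is worth fusing, whereas a large f32[1024,1024]
  // constant operand is better kept as an external constant.)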
  if (producer.opcode() == HloOpcode::kConstant &&
      (!ShapeUtil::IsEffectiveScalar(producer.shape()) ||
       consumer.opcode() != HloOpcode::kFusion)) {
    return "not fusing constant";
  }

  // Make sure the new fusion obeys the in-place semantics.
  return InstructionFusion::ShouldFuseInPlaceOp(&producer, &consumer);
}

bool IsProducerConsumerMultiOutputFusible(const HloInstruction& producer,
                                          const HloInstruction& consumer) {
  // Skip multiple output fusion. It's not yet supported.
  if (producer.IsMultiOutputFusion()) {
    return false;
  }

  // Allowing multi-output fusions that contain in-place operations makes code
  // generation more difficult. For the generated loop to iterate over all
  // outputs in parallel, it must find an iteration order that guarantees that
  // no loop iteration writes an element of any in-place operand that is read
  // or written by any other iteration. For example:
  //
  //   %fused_computation {
  //     %param_0 = s32[4,4]{1,0} parameter(0)
  //     ...
  //     %updated = s32[4,4]{1,0} dynamic-update-slice(
  //         %param_0, %add, %constant_1, %constant_0)
  //     %transpose = s32[4,4]{0,1} transpose(%updated), dimensions={1,0}
  //     ROOT %tuple.5 = tuple(%transpose, %updated)
  //   }
  //
  // Iterating 'transpose' and 'updated' in parallel by array index is
  // not valid, because an iteration that produces some element of 'transpose'
  // will read from an element of 'param_0' that has been overwritten by some
  // other iteration (writing to 'updated').
  //
  // To avoid these problems, we simply ban fusion altogether when the producer
  // is in-place. (We can relax this restriction by establishing an explicit
  // contract that describes which multi-output fusion scenarios are supported
  // by codegen, and then changing this check to allow exactly those fusions.)
  if (!HloDataflowAnalysis::GetInPlaceInputOutputPairs(&producer).empty()) {
    return false;
  }
  if (!IsLoopFusible(producer) || !IsFusibleAsMultiOutputFusionRoot(consumer)) {
    return false;
  }
  if (CreatesNestedLoop(producer, consumer)) {
    return false;
  }
  if (!ShapesCompatibleForMultiOutputFusion(producer, consumer)) {
    return false;
  }
  if (IsPhysicallyTransposing(producer)) {
    return false;
  }
  return true;
}

// Returns the shared memory usage of a given instruction in bytes.
static int64_t SharedMemoryUsageNoCache(const HloInstruction& instr) {
  // For now we are only fusing reductions.
  if (instr.opcode() == HloOpcode::kReduce &&
      IsReductionFromOrToContiguousDimensions(instr)) {
    ReductionDimensions reduction_info =
        GetReductionKindAndContiguousComponents(instr);
    int64_t primitive_size = ShapeUtil::ByteSizeOfPrimitiveType(
        instr.operand(0)->shape().element_type());
    int num_variadic =
        instr.shape().IsTuple() ? instr.shape().tuple_shapes_size() : 1;
    if (reduction_info.is_row_reduction) {
      // __shared__[32] is used for row reduction.
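      // (Worked example, assuming f32: with a 4-byte element type and a
      // single output, this is 32 * 4 = 128 bytes.)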
      return 32 * primitive_size * num_variadic;
    } else {
      // __shared__[2][32][33] cache is used for column reduction ("2" comes
      // from potential x-tiling).
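      // (Worked example, assuming f32: with a 4-byte element type and a
      // single output, this is 2 * 32 * 33 * 4 = 8448 bytes, well within
      // kSharedMemoryBudgetInBytes.)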
      return 2 * 32 * 33 * primitive_size * num_variadic;
    }
  } else if (instr.opcode() == HloOpcode::kFusion) {
    int64_t sum = 0;
    for (const HloInstruction* hlo :
         instr.fused_instructions_computation()->instructions()) {
      sum += SharedMemoryUsageNoCache(*hlo);
    }
    return sum;
  }
  // Other fused expressions for now don't need the shared memory budget.
  return 0;
}

static int64_t SharedMemoryUsage(const HloInstruction& instr,
                                 FusionInfoCache* cache = nullptr) {
  if (!cache) {
    return SharedMemoryUsageNoCache(instr);
  }

  // nb: Users are only expected to call cache->Invalidate() on top-level
  // instructions, not on instructions inside fusion nodes.  Therefore we can
  // only cache top-level instructions; it would not be valid to pass the
  // cache to SharedMemoryUsageNoCache and use the cache *within* the fusion.
  auto [it, inserted] = cache->shared_memory_usage.emplace(&instr, -1);
  if (inserted) {
    it->second = SharedMemoryUsageNoCache(instr);
  }
  return it->second;
}
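
// (Hedged usage sketch, not part of the original file: a fusion pass would
// typically keep one FusionInfoCache alive for the whole pass and invalidate
// the entry of any top-level instruction it rewrites, e.g.:
//   FusionInfoCache cache;
//   int64_t bytes = SharedMemoryUsage(*hlo, &cache);
//   ...mutate *hlo...
//   cache.Invalidate(hlo);
// Per the note above, only top-level instructions may be invalidated.)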

// Codegen'ing unnested reductions requires a lot of registers, so a
// multi-output fusion (MOF) combining many of those runs a high risk of
// spilling.
constexpr int64_t kMaxUnnestedReductionOutputsPerFusion = 8;

// Returns the number of unnested reductions in the instruction output.
static int64_t NumUnnestedReductionsNoCache(const HloInstruction& instr) {
  if (instr.opcode() == HloOpcode::kReduce &&
      IsReductionFromOrToContiguousDimensions(instr)) {
    return 1;
  }
  if (instr.opcode() == HloOpcode::kFusion) {
    int64_t sum = 0;
    for (const HloInstruction* hlo :
         instr.fused_instructions_computation()->instructions()) {
      sum += NumUnnestedReductionsNoCache(*hlo);
    }
    return sum;
  }
  return 0;
}

static int64_t NumUnnestedReductions(const HloInstruction& instr,
                                     FusionInfoCache* cache) {
  if (!cache) {
    return NumUnnestedReductionsNoCache(instr);
  }

  // nb: Users are only expected to call cache->Invalidate() on top-level
  // instructions, not on instructions inside fusion nodes.  Therefore we can
  // only cache top-level instructions; it would not be valid to pass the
  // cache to NumUnnestedReductionsNoCache and use the cache *within* the
  // fusion.
  auto [it, inserted] = cache->num_unnested_reductions.emplace(&instr, -1);
  if (inserted) {
    it->second = NumUnnestedReductionsNoCache(instr);
  }
  return it->second;
}

// This function limits the maximum number of operands to a fusion, and the
// amount of shared memory which can be consumed by the fusion.
//
// There's a cap on how many parameters we can pass to a CUDA kernel, but the
// exact limit is hazy, as it depends on (among other things) how much GPU
// constant memory is in use for other purposes.
//
// Moreover, we don't even know at the point that we're running fusion how many
// arguments the CUDA kernel for a fusion node will have: It depends on buffer
// assignment, where we will decide which of the fusion's operands live in
// XLA's big temp buffer versus in other allocations.
//
// As a heuristic, we simply cap the number of fusion operands plus outputs at
// MaxOperandsAndOutputsPerFusion().  This puts an upper bound on the number of
// parameters to the kernel, working around the correctness problem.
//
// This limit is also often good for performance.  In a fusion with many
// operands, each GPU thread likely has to do a lot of work, and so possibly
// uses a lot of registers, thus limiting occupancy.
//
// If the fusion is a producer/consumer fusion and instr1 is the consumer and
// instr2 is the producer, set is_consumer_producer_fusion to true to enable
// more fusion.
FusionDecision FusionFitsInBudget(const HloInstruction& instr1,
                                  const HloInstruction& instr2,
                                  bool is_consumer_producer_fusion,
                                  FusionInfoCache* cache /*=nullptr*/) {
  if (SharedMemoryUsage(instr1, cache) + SharedMemoryUsage(instr2, cache) >
      kSharedMemoryBudgetInBytes) {
    return FusionDecision{}
           << "shared memory usage would be over the budget of "
           << kSharedMemoryBudgetInBytes << "B";
  }

  if (NumUnnestedReductions(instr1, cache) +
          NumUnnestedReductions(instr2, cache) >
      kMaxUnnestedReductionOutputsPerFusion) {
    return FusionDecision{} << "over " << kMaxUnnestedReductionOutputsPerFusion
                            << " unnested reductions in fusion";
  }

  // Compute the number of outputs of the (possibly multi-output) fusion node
  // we're considering creating.
  //
  // This isn't precise; we may be off by one if
  //  - We're creating a multi-output fusion out of two non-MOFs.  Creating a
  //    MOF adds a new buffer, namely, the tuple buffer.
  //  - We're merging two MOFs.  In this case, we should count the tuple buffer
  //    only once.
  //  - WLOG there's an edge from `a` to `b` and `b` is the only consumer of
  //    `a`.  In this case the result of `a` is not part of the output of the
  //    fusion.
  //
  // But because this is a heuristic and our limit
  // MaxOperandsAndOutputsPerFusion() is a large value (so +/- 1 doesn't make a
  // big difference), we ignore this small inaccuracy in favor of simplicity.
  int64_t num_output_buffers = ShapeUtil::SubshapeCount(instr1.shape()) +
                               ShapeUtil::SubshapeCount(instr2.shape());

  // The new fusion will have no more operands and outputs than
  //   producer_operands + consumer_operands - 1 + num_output_buffers
  // (minus one because we may be fusing a producer->consumer edge between `a`
  // and `b`).
  //
  // This fact may be enough to let us avoid having to compute the true total
  // number of operands, which can be expensive.
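  // (Worked example, not from the original: fusing a producer with 3 operands
  // into a consumer with 4 operands, one of which is the producer itself,
  // yields at most 3 + 4 - 1 = 6 distinct operands.)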
  if (instr1.operand_count() + instr2.operand_count() - 1 +
          num_output_buffers <=
      MaxOperandsAndOutputsPerFusion()) {
    return {};
  } else {
    VLOG(5) << "Operand count of "
            << "(" << instr1.ToString() << " ) = " << instr1.operand_count()
            << " and ( " << instr2.ToString()
            << " ) = " << instr2.operand_count()
            << " and num_output_buffers = " << num_output_buffers
            << " is bigger than the bound of "
            << MaxOperandsAndOutputsPerFusion();
  }

  // Compute the precise number of operands to the new fusion.
  absl::flat_hash_set<const HloInstruction*> operands(instr1.operands().begin(),
                                                      instr1.operands().end());
  operands.insert(instr2.operands().begin(), instr2.operands().end());
  // If there's an edge between `instr1` and `instr2`, don't count it: We're
  // fusing that producer -> consumer relationship.
  operands.erase(&instr1);
  operands.erase(&instr2);

  // If the fused node uses no more inputs than the consumer did before, the
  // fusion won't grow the parameter count, so accept it. As this is a
  // consumer/producer fusion, the number of outputs of the consumer does not
  // change, so there is no need to check it.
  if (is_consumer_producer_fusion &&
      operands.size() <= instr1.operands().size()) {
    return {};
  }

  // Does the new fusion have more operands and outputs than the max?
  if (operands.size() + num_output_buffers > MaxOperandsAndOutputsPerFusion()) {
    return "Number of operands and output buffers is larger than allowed "
           "budget per fusion";
  }
  return {};
}

bool CreatesNestedLoop(const HloInstruction& producer,
                       const HloInstruction& consumer) {
  // If the producer does not have an instruction that codegens a loop, then
  // there is nothing to do.
  auto producer_has_loop_codegen = [&](const HloInstruction& instr) {
    if (instr.opcode() != HloOpcode::kFusion) {
      return IfFusedReadsElementsMultipleTimes(instr);
    }
    for (const auto& fused_instr : instr.fused_instructions()) {
      if (IfFusedReadsElementsMultipleTimes(*fused_instr)) {
        return true;
      }
    }
    return false;
  };
  if (!producer_has_loop_codegen(producer)) {
    return false;
  }

  // If the consumer is a non-fusion instruction, then we have to check whether
  // it generates a loop.
  if (consumer.opcode() != HloOpcode::kFusion) {
    return IfFusedReadsElementsMultipleTimes(consumer);
  }

  // If the consumer is a fusion, then we have to check whether the output of
  // the producer is used, directly or indirectly, as an input to an HLO
  // instruction that generates a loop, i.e. whether there is a path in the
  // graph from an operand corresponding to the producer to an HLO instruction
  // generating a loop in the consumer.
  for (const HloInstruction* operand : consumer.operands()) {
    if (operand != &producer) {
      continue;
    }

    const HloInstruction* root =
        consumer.fused_instructions_computation()->parameter_instruction(
            consumer.operand_index(operand));

    std::stack<const HloInstruction*> dfs;
    dfs.push(root);
    absl::flat_hash_set<const HloInstruction*> visited;
    while (!dfs.empty()) {
      const HloInstruction* cur = dfs.top();
      dfs.pop();

      if (visited.contains(cur)) {
        continue;
      }
      visited.insert(cur);

      if (IfFusedReadsElementsMultipleTimes(*cur)) {
        return true;
      }
      for (const auto& user : cur->users()) {
        if (visited.contains(user)) {
          continue;
        }
        dfs.push(user);
      }
    }
  }
  return false;
}

bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr) {
  // We can fuse reduces and loop fusions. Elementwise instructions can be
  // fused with any other instruction.
  // Note that scatter cannot be the root of a multi-output fusion because
  // its emitter doesn't support it.
  return instr.IsFusible() &&
         (IsInputFusibleReduction(instr) ||
          instr.IsLoopFusion() ||  // TODO(b/130013493): Use IsLoopFusible here.
          instr.IsElementwise());
}

HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& /*producer*/,
                                            const HloInstruction& consumer) {
  return IsInputFusible(consumer) ? HloInstruction::FusionKind::kInput
                                  : HloInstruction::FusionKind::kLoop;
}

bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr,
                                  const HloInstruction& consumer) {
  return absl::c_all_of(instr.users(), [&](const HloInstruction* user) {
    if (user->opcode() == HloOpcode::kGetTupleElement) {
      // Skip GTE.
      return IsConsumerTheOnlyNonRootUser(*user, consumer);
    }
    if (user == &consumer) {
      // `user` is `consumer`.
      return true;
    }
    if (user == user->parent()->root_instruction()) {
      // Consumed by ROOT.
      return true;
    }
    return false;
  });
}

size_t GetInstrCountOfFusible(const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kFusion) {
    return 1;
  }
  return instr.fused_instruction_count();
}

absl::InlinedVector<const HloInstruction*, 2> GetOutputsOfFusible(
    const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kFusion) {
    return {&instr};
  }

  HloInstruction* root = instr.fused_expression_root();
  if (root->opcode() != HloOpcode::kTuple) {
    return {root};
  }
  auto v = root->operands();
  return absl::InlinedVector<const HloInstruction*, 2>(v.begin(), v.end());
}

size_t GetOutputSizeOfFusible(const HloInstruction& instr) {
  if (!instr.IsMultiOutputFusion()) {
    return 1;
  }
  const HloInstruction* root = instr.fused_expression_root();
  return ShapeUtil::TupleElementCount(root->shape());
}

}  // namespace gpu
}  // namespace xla