/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"

#include <algorithm>
#include <iterator>
#include <stack>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {
namespace gpu {
namespace {

// The amount of shared memory a CUDA kernel can use.
//
// Stay on the conservative side: this is smaller than the full 64kB, which
// leaves some extra space for the cache.
int64_t kSharedMemoryBudgetInBytes = 40000;

void AppendParams(const HloInstruction& instr,
                  std::vector<HloInstruction*>* params) {
  if (instr.opcode() == HloOpcode::kFusion) {
    params->insert(std::end(*params), std::begin(instr.fused_parameters()),
                   std::end(instr.fused_parameters()));
  } else {
    for (HloInstruction* operand : instr.operands()) {
      params->push_back(operand);
    }
  }
}

bool IfFusedReadsElementsMultipleTimes(const HloInstruction& instr) {
  CHECK_NE(instr.opcode(), HloOpcode::kFusion) << "`instr` has to be unfused.";
  if (instr.opcode() == HloOpcode::kReduce &&
      !IsReductionFromOrToContiguousDimensions(instr)) {
    return true;
  }
  // Avoid fusing reduce-window when the window's stride is smaller than its
  // size, to limit the number of times the same elements are read.
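  // For example (illustrative numbers): a 1-D window of size 3 with stride 1
  // reads each interior element three times.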
  if (instr.opcode() == HloOpcode::kReduceWindow) {
    for (const auto& dim : instr.window().dimensions()) {
      if (dim.size() > dim.stride()) {
        return true;
      }
    }
  }
  return false;
}

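// Returns the relative order of the shape's nontrivial (size > 1) dimensions,
// normalized to [0, true rank). For example (illustrative shape): for
// f32[1,8,1,4] with minor-to-major layout {3,2,1,0}, the nontrivial dims in
// minor-to-major order are {3, 1}, which normalize to {1, 0}.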
std::vector<int64_t> ExtractRelativeOrderOfNontrivialDims(const Shape& shape) {
  std::vector<int64_t> relative_order;
  for (int64_t dim : LayoutUtil::MinorToMajor(shape)) {
    if (shape.dimensions(dim) > 1) {
      relative_order.push_back(dim);
    }
  }
  // Now normalize the dimensions to values between 0 and true rank - 1.
  std::vector<int64_t> sorted_dims = relative_order;
  std::sort(sorted_dims.begin(), sorted_dims.end());
  for (int64_t& dim : relative_order) {
    int64_t sorted_index = std::distance(
        sorted_dims.begin(),
        std::lower_bound(sorted_dims.begin(), sorted_dims.end(), dim));
    dim = sorted_index;
  }
  return relative_order;
}

}  // namespace

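// For example (illustrative shapes): parameters f32[64,128]{1,0} and
// f32[64,128]{0,1} have relative nontrivial-dim orders {1,0} and {0,1}
// respectively, so mixing them as inputs of a reduce input fusion is
// layout-unfriendly and this returns false.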
bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
                                         const HloInstruction& reduce) {
  std::vector<HloInstruction*> params;
  AppendParams(producer, &params);
  AppendParams(reduce, &params);
  int64_t max_true_rank = -1;
  std::vector<int64_t> max_rank_order;
  for (HloInstruction* param : params) {
    if (param->shape().IsArray() &&
        ShapeUtil::TrueRank(param->shape()) > max_true_rank) {
      max_true_rank = ShapeUtil::TrueRank(param->shape());
      max_rank_order = ExtractRelativeOrderOfNontrivialDims(param->shape());
    }
  }
  return absl::c_all_of(params, [&](HloInstruction* param) {
    return !param->shape().IsArray() ||
           ShapeUtil::TrueRank(param->shape()) < max_true_rank ||
           ExtractRelativeOrderOfNontrivialDims(param->shape()) ==
               max_rank_order;
  });
}

bool IsReduceInputFusion(const HloInstruction& instr) {
  if (instr.IsMultiOutputFusion()) {
    for (const HloInstruction* operand :
         instr.fused_expression_root()->operands()) {
      if (IsReductionFromOrToContiguousDimensions(*operand)) {
        CHECK(instr.IsInputFusion())
            << " Multi-output fusion rooted at reduction-to-vector ops must be "
               "of kind kInput: "
            << instr.ToString();
        return true;
      }
    }
  } else if (instr.opcode() == HloOpcode::kFusion &&
             IsReductionFromOrToContiguousDimensions(
                 *instr.fused_expression_root())) {
    CHECK(instr.IsInputFusion())
        << " Fusion rooted at reduction-to-vector op must be of kind kInput: "
        << instr.ToString();
    return true;
  }
  return false;
}

bool IsInputFusibleReduction(const HloInstruction& instr) {
  // TODO(b/129089333): Don't fuse variadic reduce.
  if (instr.opcode() == HloOpcode::kReduce && instr.shape().IsTuple()) {
    return false;
  }

  return IsReduceInputFusion(instr) ||
         IsReductionFromOrToContiguousDimensions(instr);
}

const HloInstruction* GetRealHeroForMultiOutputFusion(
    const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kFusion) {
    return &instr;
  }
  auto fused_expression_root = instr.fused_expression_root();
  if (!instr.IsMultiOutputFusion()) {
    return fused_expression_root;
  }
  // If possible, we want to pick a reduction-from-or-to-contiguous-dims
  // operand of the fusion root, because it has the most constraints.
  for (const auto* inst : fused_expression_root->operands()) {
    if (IsReductionFromOrToContiguousDimensions(*inst)) {
      return inst;
    }
  }
  return fused_expression_root->operands()[0];
}

bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
                                          const HloInstruction& instr2) {
  // Multi-output fusion kernels share a common parallel loop. The loop
  // dimensions are determined by instruction shapes.
  auto get_loop_shape = [&](const HloInstruction* element_instr) {
    // Special-case reduction-to-vector ops: The loop dimensions are determined
    // by the shape of the first operand.
    if (IsReductionFromOrToContiguousDimensions(*element_instr)) {
      return element_instr->operand(0)->shape();
    }
    return element_instr->shape();
  };

  // All shapes of the root tuple of a multi-output fusion should agree, i.e.
  // all root ops should have equal output shapes. The exception is
  // reduction-to-vector ops: there, the input shape of the reduction (the
  // first operand's shape) and the reduction dimensions need to match instead.
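  // For example (illustrative shapes): an add producing f32[100,200] and a
  // reduce of an f32[100,200] input down to f32[100] are compatible, because
  // the reduce's loop shape is its input shape, f32[100,200].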
  auto* instr_1 = GetRealHeroForMultiOutputFusion(instr1);
  auto* instr_2 = GetRealHeroForMultiOutputFusion(instr2);
  if (IsReductionFromOrToContiguousDimensions(*instr_1) &&
      IsReductionFromOrToContiguousDimensions(*instr_2) &&
      !AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) {
    return false;
  }
  // The elementwise output shapes must be the same (including layout).
  return ShapeUtil::EqualIgnoringElementType(get_loop_shape(instr_1),
                                             get_loop_shape(instr_2));
}

bool IsInputFusibleScatter(const HloInstruction& instr) {
  if (instr.opcode() == HloOpcode::kScatter ||
      (instr.opcode() == HloOpcode::kFusion &&
       instr.fusion_kind() == HloInstruction::FusionKind::kInput &&
       instr.fused_expression_root()->opcode() == HloOpcode::kScatter)) {
    return true;
  }
  return false;
}

bool IsInputFusible(const HloInstruction& instr) {
  // Input fusion only handles non-elementwise reduction and scatter
  // operations.
  return instr.IsFusible() &&
         (IsInputFusibleReduction(instr) || IsInputFusibleScatter(instr));
}

bool IsLoopFusible(const HloInstruction& instr) {
  // Don't fuse get-tuple-element on GPU: We can, but it's slower than not
  // fusing.  We never generate kernels for unfused GTEs.  Instead, if an
  // unfused GTE is an input to a kernel (including a fusion kernel), we
  // compute the address of the GTE at the top of the kernel.  Often we know
  // the address of the GTE result statically, so we can do this without
  // chasing any pointers.
  return instr.IsFusible() &&
         ((instr.IsElementwise() && instr.operand_count() > 0) ||
          instr.opcode() == HloOpcode::kBitcast ||
          instr.opcode() == HloOpcode::kBroadcast ||
          instr.opcode() == HloOpcode::kConcatenate ||
          instr.opcode() == HloOpcode::kDynamicSlice ||
          instr.opcode() == HloOpcode::kDynamicUpdateSlice ||
          (instr.opcode() == HloOpcode::kFusion &&
           instr.fusion_kind() == HloInstruction::FusionKind::kLoop) ||
          instr.opcode() == HloOpcode::kGather ||
          instr.opcode() == HloOpcode::kIota ||
          instr.opcode() == HloOpcode::kPad ||
          (instr.opcode() == HloOpcode::kReduce &&
           !IsReductionFromOrToContiguousDimensions(instr) &&
           !instr.shape().IsTuple()) ||  // TODO(b/129089333): Don't fuse
                                         // variadic reductions.
          instr.opcode() == HloOpcode::kReduceWindow ||
          instr.opcode() == HloOpcode::kReshape ||
          instr.opcode() == HloOpcode::kReverse ||
          instr.opcode() == HloOpcode::kSlice ||
          instr.opcode() == HloOpcode::kConstant ||
          instr.opcode() == HloOpcode::kTranspose);
}

bool IsProducerConsumerFusible(const HloInstruction& producer,
                               const HloInstruction& consumer) {
  if (!IsLoopFusible(producer)) {
    VLOG(5) << "Producer " << producer.name() << " is not loop-fusible";
    return false;
  }

  if (!IsInputFusible(consumer) && !IsLoopFusible(consumer)) {
    VLOG(5) << "Consumer " << consumer.name()
            << " is not input-fusible and not loop-fusible";
    return false;
  }

  // Skip multiple output fusion. It's not yet supported.
  if (producer.IsMultiOutputFusion()) {
    VLOG(5) << "Producer " << producer.name()
            << " is not fusible as it is a multi-output fusion";
    return false;
  }

  if (CreatesNestedLoop(producer, consumer)) {
    VLOG(5) << "Fusing " << producer.name() << " into " << consumer.name()
            << " creates nested loop";
    return false;
  }

  // Do not fuse into reduce input fusions if the resulting kernel would suffer
  // from poor data locality (due to unfriendly input layouts).
  if (IsInputFusibleReduction(consumer) &&
      !LayoutsAreReduceInputFusionFriendly(producer, consumer)) {
    VLOG(5) << "Layout of " << producer.name()
            << " is not fusion-friendly for consumer reduction "
            << consumer.name();
    return false;
  }

  // Fuse scalar constants into loop fusion nodes. This reduces the number of
  // parameters and makes matching scalar broadcasts easier.
  //
  // Don't fuse other constants: Unfused constants in GPU land can be
  // represented as an external constant (i.e. not emitted in LLVM IR / PTX),
  // but fused constants are handled by shared CPU/GPU code and always emitted
  // in the IR/PTX.  The external constant representation makes for faster
  // compiles and significantly smaller assembly code.
  if (producer.opcode() == HloOpcode::kConstant &&
      (!ShapeUtil::IsEffectiveScalar(producer.shape()) ||
       consumer.opcode() != HloOpcode::kFusion)) {
    VLOG(5) << "Not fusing constant " << producer.name() << " into "
            << consumer.name();
    return false;
  }

  return true;
}

bool IsProducerConsumerMultiOutputFusible(const HloInstruction& producer,
                                          const HloInstruction& consumer) {
  // Skip multiple output fusion. It's not yet supported.
  if (producer.IsMultiOutputFusion()) {
    return false;
  }

  if (!IsLoopFusible(producer) || !IsFusibleAsMultiOutputFusionRoot(consumer)) {
    return false;
  }
  if (CreatesNestedLoop(producer, consumer)) {
    return false;
  }
  if (!ShapesCompatibleForMultiOutputFusion(producer, consumer)) {
    return false;
  }
  if (!LayoutsAreReduceInputFusionFriendly(producer, consumer)) {
    return false;
  }
  return true;
}

// Returns shared memory usage for a given instruction in bytes.
static int64_t SharedMemoryUsage(const HloInstruction& instr) {
  // For now we are only fusing reductions.
  if (IsReductionFromOrToContiguousDimensions(instr)) {
    ReductionDimensions reduction_info =
        GetReductionKindAndContiguousComponents(instr);
    int64_t primitive_size =
        ShapeUtil::ByteSizeOfPrimitiveType(instr.shape().element_type());
    if (reduction_info.is_row_reduction) {
      // __shared__[32] is used for row reduction.
      return 32 * primitive_size;
    } else {
      // A __shared__[2][32][33] cache is used for column reduction ("2" comes
      // from potential x-tiling).
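      // For f32, for example, that is 2 * 32 * 33 * 4 = 8448 bytes.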
      return 2 * 32 * 33 * primitive_size;
    }
  } else if (instr.opcode() == HloOpcode::kFusion) {
    int64_t sum = 0;
    for (const HloInstruction* hlo :
         instr.fused_instructions_computation()->MakeInstructionPostOrder()) {
      sum += SharedMemoryUsage(*hlo);
    }
    return sum;
  }
  // Other fused expressions for now don't need the shared memory budget.
  return 0;
}

// Codegen'ing unnested reductions requires a lot of registers, so a MOF
// combining many of those runs a high risk of spilling.
constexpr int64_t kMaxUnnestedReductionOutputsPerFusion = 8;

// Returns the number of unnested reductions in the instruction output.
static int64_t NumUnnestedReductions(const HloInstruction& instr) {
  if (IsReductionFromOrToContiguousDimensions(instr)) {
    return 1;
  }
  if (instr.opcode() == HloOpcode::kFusion) {
    int64_t sum = 0;
    for (const HloInstruction* hlo :
         instr.fused_instructions_computation()->MakeInstructionPostOrder()) {
      sum += NumUnnestedReductions(*hlo);
    }
    return sum;
  }
  return 0;
}

// This function limits the maximum number of operands to a fusion, and the
// amount of shared memory which can be consumed by the fusion.
//
// There's a cap on how many parameters we can pass to a CUDA kernel, but the
// exact limit is hazy, as it depends on (among other things) how much GPU
// constant memory is in use for other purposes.
//
// Moreover, we don't even know at the point that we're running fusion how many
// arguments the CUDA kernel for a fusion node will have: It depends on buffer
// assignment, where we will decide which of the fusion's operands live in XLA's
// big temp buffer versus in other allocations.
//
// As a heuristic, we simply cap the number of fusion operands plus outputs at
// kMaxOperandsAndOutputsPerFusion.  This puts an upper bound on the number of
// parameters to the kernel, working around the correctness problem.
//
// This limit is also often good for performance.  In a fusion with many
// operands, each GPU thread likely has to do a lot of work, and so possibly
// uses a lot of registers, thus limiting occupancy.
//
// If the fusion is a producer/consumer fusion, with instr1 the consumer and
// instr2 the producer, set is_consumer_producer_fusion to true to enable more
// fusion.
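//
// Illustrative arithmetic (hypothetical counts): fusing a producer with 3
// operands into a consumer with 4 operands along a single edge, with 2 output
// buffers, bounds the parameter count at 3 + 4 - 1 + 2 = 8.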
bool FusionWouldBeTooLarge(const HloInstruction& instr1,
                           const HloInstruction& instr2,
                           bool is_consumer_producer_fusion) {
  if (SharedMemoryUsage(instr1) + SharedMemoryUsage(instr2) >
      kSharedMemoryBudgetInBytes) {
    VLOG(5) << "Shared memory usage of fusion of " << instr1.ToString()
            << " and " << instr2.ToString() << " would be over the budget of "
            << kSharedMemoryBudgetInBytes << "B";
    return true;
  }

  if (NumUnnestedReductions(instr1) + NumUnnestedReductions(instr2) >
      kMaxUnnestedReductionOutputsPerFusion) {
    VLOG(5) << "Not fusing over " << kMaxUnnestedReductionOutputsPerFusion
            << " unnested reductions in fusion";
    return true;
  }

  // Compute the number of outputs of the (possibly multi-output) fusion node
  // we're considering creating.
  //
  // This isn't precise; we may be off by one if
  //  - We're creating a multi-output fusion out of two non-MOFs.  Creating a
  //    MOF adds a new buffer, namely, the tuple buffer.
  //  - We're merging two MOFs.  In this case, we should count the tuple buffer
  //    only once.
  //  - WLOG there's an edge from `a` to `b` and `b` is the only consumer of
  //    `a`.  In this case the result of `a` is not part of the output of the
  //    fusion.
  //
  // But because this is a heuristic and our limit
  // kMaxOperandsAndOutputsPerFusion is a large value (so +/- 1 doesn't make a
  // big difference), we ignore this small inaccuracy in favor of simplicity.
  int64_t num_output_buffers = ShapeUtil::SubshapeCount(instr1.shape()) +
                               ShapeUtil::SubshapeCount(instr2.shape());

  // The new fusion will have no more operands and outputs than
  //   producer_operands + consumer_operands - 1 + num_output_buffers
  // (minus one because we may be fusing a producer->consumer edge between `a`
  // and `b`).
  //
  // This fact may be enough to let us avoid having to compute the true total
  // number of operands, which can be expensive.
  if (instr1.operand_count() + instr2.operand_count() - 1 +
          num_output_buffers <=
      kMaxOperandsAndOutputsPerFusion) {
    return false;
  } else {
    VLOG(5) << "Operand count of "
            << "(" << instr1.ToString() << " ) = " << instr1.operand_count()
            << " and ( " << instr2.ToString()
            << " ) = " << instr2.operand_count()
            << " and num_output_buffers = " << num_output_buffers
            << " is bigger than the bound of "
            << kMaxOperandsAndOutputsPerFusion;
  }

  // Compute the precise number of operands to the new fusion.
  absl::flat_hash_set<const HloInstruction*> operands(instr1.operands().begin(),
                                                      instr1.operands().end());
  operands.insert(instr2.operands().begin(), instr2.operands().end());
  // If there's an edge between `a` and `b`, don't count it: We're fusing that
  // producer -> consumer relationship.
  operands.erase(&instr1);
  operands.erase(&instr2);

  // If the fusion introduces no more inputs than the consumer already had, the
  // kernel won't grow after fusion, so accept it.  Since this is a
  // consumer/producer fusion, the consumer's number of outputs does not
  // change, so there is no need to check it.
  if (is_consumer_producer_fusion &&
      operands.size() <= instr1.operands().size()) {
    return false;
  }

  // Does the new fusion have more operands and outputs than the max?
  return operands.size() + num_output_buffers > kMaxOperandsAndOutputsPerFusion;
}

bool CreatesNestedLoop(const HloInstruction& producer,
                       const HloInstruction& consumer) {
  // If the producer does not have an instruction that codegens a loop, then
  // there is nothing to do.
  auto producer_has_loop_codegen = [&](const HloInstruction& instr) {
    if (instr.opcode() != HloOpcode::kFusion) {
      return IfFusedReadsElementsMultipleTimes(instr);
    }
    for (const auto& fused : instr.fused_instructions()) {
      if (IfFusedReadsElementsMultipleTimes(*fused)) {
        return true;
      }
    }
    return false;
  };
  if (!producer_has_loop_codegen(producer)) {
    return false;
  }

  // If the consumer is a non-fusion instruction, then we have to check whether
  // it generates a loop.
  if (consumer.opcode() != HloOpcode::kFusion) {
    return IfFusedReadsElementsMultipleTimes(consumer);
  }

  // If the consumer is a fusion, then we have to check whether the output of
  // the producer is used, directly or indirectly, as an input to an HLO
  // instruction that generates a loop, i.e. whether there is a path in the
  // graph from an operand corresponding to the producer to an HLO instruction
  // generating a loop in the consumer.
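  // For example (hypothetical HLO): if the producer feeds a parameter of the
  // consumer fusion that eventually reaches a reduce-window with overlapping
  // windows, every producer element would be re-read once per window covering
  // it, nesting the producer's loop inside the reduce-window's loop.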
  for (const HloInstruction* operand : consumer.operands()) {
    if (operand != &producer) {
      continue;
    }

    const HloInstruction* root =
        consumer.fused_instructions_computation()->parameter_instruction(
            consumer.operand_index(operand));

    std::stack<const HloInstruction*> dfs;
    dfs.push(root);
    absl::flat_hash_set<const HloInstruction*> visited;
    while (!dfs.empty()) {
      const HloInstruction* cur = dfs.top();
      dfs.pop();

      if (visited.contains(cur)) {
        continue;
      }
      visited.insert(cur);

      if (IfFusedReadsElementsMultipleTimes(*cur)) {
        return true;
      }
      for (const auto& user : cur->users()) {
        if (visited.contains(user)) {
          continue;
        }
        dfs.push(user);
      }
    }
  }
  return false;
}

bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr) {
  // We can fuse reduces and loop fusions. Elementwise instructions can be
  // fused with any other instruction.
  // Note that scatter cannot be the root of a multi-output fusion because
  // its emitter doesn't support it.

  return instr.IsFusible() &&
         (IsInputFusibleReduction(instr) ||
          instr.IsLoopFusion() ||  // TODO(b/130013493): Use IsLoopFusible here.
          instr.IsElementwise());
}

HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& /*producer*/,
                                            const HloInstruction& consumer) {
  return IsInputFusible(consumer) ? HloInstruction::FusionKind::kInput
                                  : HloInstruction::FusionKind::kLoop;
}

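// Returns true if every user of `instr` is either `consumer` itself or the
// computation's root, looking through any intervening get-tuple-element users.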
bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr,
                                  const HloInstruction& consumer) {
  return absl::c_all_of(instr.users(), [&](const HloInstruction* user) {
    if (user->opcode() == HloOpcode::kGetTupleElement) {
      // Skip GTE.
      return IsConsumerTheOnlyNonRootUser(*user, consumer);
    }
    if (user == &consumer) {
      // `user` is `consumer`.
      return true;
    }
    if (user == user->parent()->root_instruction()) {
      // Consumed by ROOT.
      return true;
    }
    return false;
  });
}

size_t GetInstrCountOfFusible(const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kFusion) {
    return 1;
  } else {
    return instr.fused_instruction_count();
  }
}

absl::InlinedVector<const HloInstruction*, 2> GetOutputsOfFusible(
    const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kFusion) {
    return {&instr};
  }

  HloInstruction* root = instr.fused_expression_root();
  if (root->opcode() != HloOpcode::kTuple) {
    return {root};
  } else {
    auto v = root->operands();
    return absl::InlinedVector<const HloInstruction*, 2>(v.begin(), v.end());
  }
}

size_t GetOutputSizeOfFusible(const HloInstruction& instr) {
  if (!instr.IsMultiOutputFusion()) {
    return 1;
  }
  const HloInstruction* root = instr.fused_expression_root();
  return ShapeUtil::TupleElementCount(root->shape());
}

}  // namespace gpu
}  // namespace xla