/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"

#include <algorithm>
#include <iterator>
#include <stack>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {
namespace gpu {
namespace {

// The amount of shared memory a CUDA kernel can use.
//
// Stay on the conservative side: this is smaller than the full 64kB, but
// leaves some extra space for cache.
constexpr int64_t kSharedMemoryBudgetInBytes = 40000;

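// Appends the parameters of `instr` to `params`: the fused parameters if
// `instr` is a fusion, otherwise its operands.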
void AppendParams(const HloInstruction& instr,
                  std::vector<HloInstruction*>* params) {
  if (instr.opcode() == HloOpcode::kFusion) {
    params->insert(std::end(*params), std::begin(instr.fused_parameters()),
                   std::end(instr.fused_parameters()));
  } else {
    for (HloInstruction* operand : instr.operands()) {
      params->push_back(operand);
    }
  }
}

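// Returns true if codegen for the unfused `instr` reads some elements of an
// input more than once: a reduce that is not a reduction from-or-to
// contiguous dimensions, or a reduce-window whose windows overlap.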
bool IfFusedReadsElementsMultipleTimes(const HloInstruction& instr) {
  CHECK_NE(instr.opcode(), HloOpcode::kFusion) << "`instr` has to be unfused.";
  if (instr.opcode() == HloOpcode::kReduce &&
      !IsReductionFromOrToContiguousDimensions(instr)) {
    return true;
  }
  // Avoid fusing reduce-window when the stride is smaller than the window
  // size: adjacent windows then overlap, so the same elements are read
  // multiple times.
  if (instr.opcode() == HloOpcode::kReduceWindow) {
    for (const auto& dim : instr.window().dimensions()) {
      if (dim.size() > dim.stride()) {
        return true;
      }
    }
  }
  return false;
}

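// Returns the relative order of the nontrivial (size > 1) dimensions of
// `shape`, minor to major, renumbered to the range [0, true rank). For
// example, f32[1,8,1,4] with layout {3,2,1,0} has nontrivial dimensions
// {3, 1} in minor-to-major order, which normalize to {1, 0}.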
std::vector<int64_t> ExtractRelativeOrderOfNontrivialDims(const Shape& shape) {
  std::vector<int64_t> relative_order;
  for (int64_t dim : LayoutUtil::MinorToMajor(shape)) {
    if (shape.dimensions(dim) > 1) {
      relative_order.push_back(dim);
    }
  }
  // Now normalize the dimensions to values between 0 and true rank - 1.
  std::vector<int64_t> sorted_dims = relative_order;
  std::sort(sorted_dims.begin(), sorted_dims.end());
  for (int64_t& dim : relative_order) {
    int64_t sorted_index = std::distance(
        sorted_dims.begin(),
        std::lower_bound(sorted_dims.begin(), sorted_dims.end(), dim));
    dim = sorted_index;
  }
  return relative_order;
}

}  // namespace

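// Returns true if fusing `producer` into the reduce input fusion `reduce`
// preserves data locality: every array parameter whose true rank equals the
// maximal true rank must agree on the relative order of its nontrivial
// dimensions.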
bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
                                         const HloInstruction& reduce) {
  std::vector<HloInstruction*> params;
  AppendParams(producer, &params);
  AppendParams(reduce, &params);
  int64_t max_true_rank = -1;
  std::vector<int64_t> max_rank_order;
  for (HloInstruction* param : params) {
    if (param->shape().IsArray() &&
        ShapeUtil::TrueRank(param->shape()) > max_true_rank) {
      max_true_rank = ShapeUtil::TrueRank(param->shape());
      max_rank_order = ExtractRelativeOrderOfNontrivialDims(param->shape());
    }
  }
  return absl::c_all_of(params, [&](HloInstruction* param) {
    return !param->shape().IsArray() ||
           ShapeUtil::TrueRank(param->shape()) < max_true_rank ||
           ExtractRelativeOrderOfNontrivialDims(param->shape()) ==
               max_rank_order;
  });
}

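// Returns true if `instr` is a fusion whose root (or, for a multi-output
// fusion, one of the root tuple's operands) is a reduction from-or-to
// contiguous dimensions. Such fusions must be of kind kInput.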
bool IsReduceInputFusion(const HloInstruction& instr) {
  if (instr.IsMultiOutputFusion()) {
    for (const HloInstruction* operand :
         instr.fused_expression_root()->operands()) {
      if (IsReductionFromOrToContiguousDimensions(*operand)) {
        CHECK(instr.IsInputFusion())
            << " Multi-output fusion rooted at reduction-to-vector ops must "
               "be of kind kInput: "
            << instr.ToString();
        return true;
      }
    }
  } else if (instr.opcode() == HloOpcode::kFusion &&
             IsReductionFromOrToContiguousDimensions(
                 *instr.fused_expression_root())) {
    CHECK(instr.IsInputFusion())
        << " Fusion rooted at reduction-to-vector op must be of kind kInput: "
        << instr.ToString();
    return true;
  }
  return false;
}

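// Returns true if `instr` is fusible as the hero of a reduce input fusion:
// either an existing reduce input fusion or an unfused reduction from-or-to
// contiguous dimensions. Variadic (tuple-shaped) reduces are excluded.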
bool IsInputFusibleReduction(const HloInstruction& instr) {
  // TODO(b/129089333): Don't fuse variadic reduce.
  if (instr.opcode() == HloOpcode::kReduce && instr.shape().IsTuple()) {
    return false;
  }

  return IsReduceInputFusion(instr) ||
         IsReductionFromOrToContiguousDimensions(instr);
}

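// Returns the instruction that dictates the emitted loop structure for
// `instr`: the instruction itself if unfused, otherwise the fusion root or,
// for multi-output fusions, a suitable operand of the root tuple.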
const HloInstruction* GetRealHeroForMultiOutputFusion(
    const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kFusion) {
    return &instr;
  }
  auto fused_expression_root = instr.fused_expression_root();
  if (!instr.IsMultiOutputFusion()) {
    return fused_expression_root;
  }
  // If possible, we want to pick a reduction-from-or-to-contiguous-dims
  // operand of the fusion root, because it has the most constraints.
  for (const auto* inst : fused_expression_root->operands()) {
    if (IsReductionFromOrToContiguousDimensions(*inst)) {
      return inst;
    }
  }
  return fused_expression_root->operands()[0];
}

bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
                                          const HloInstruction& instr2) {
  // Multi-output fusion kernels share a common parallel loop. The loop
  // dimensions are determined by instruction shapes.
  auto get_loop_shape = [&](const HloInstruction* element_instr) {
    // Special-case reduction-to-vector ops: The loop dimensions are
    // determined by the shape of the first operand.
    if (IsReductionFromOrToContiguousDimensions(*element_instr)) {
      return element_instr->operand(0)->shape();
    }
    return element_instr->shape();
  };

  // All shapes of the root tuple of multi-output fusions should agree, i.e.
  // all root ops should have equal output shapes. The exception is
  // reduction-to-vector ops: there, the input shapes of the reductions (the
  // first operand shapes) and the reduction dimensions must match.
  auto* instr_1 = GetRealHeroForMultiOutputFusion(instr1);
  auto* instr_2 = GetRealHeroForMultiOutputFusion(instr2);
  if (IsReductionFromOrToContiguousDimensions(*instr_1) &&
      IsReductionFromOrToContiguousDimensions(*instr_2) &&
      !AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) {
    return false;
  }
  // The elementwise output shapes must be the same (including layout).
  return ShapeUtil::EqualIgnoringElementType(get_loop_shape(instr_1),
                                             get_loop_shape(instr_2));
}

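// Returns true if `instr` is a scatter, or an input fusion whose root is a
// scatter.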
bool IsInputFusibleScatter(const HloInstruction& instr) {
  return instr.opcode() == HloOpcode::kScatter ||
         (instr.opcode() == HloOpcode::kFusion &&
          instr.fusion_kind() == HloInstruction::FusionKind::kInput &&
          instr.fused_expression_root()->opcode() == HloOpcode::kScatter);
}

bool IsInputFusible(const HloInstruction& instr) {
  // Input fusion only handles non-elemental reduction and scatter operations.
  return instr.IsFusible() &&
         (IsInputFusibleReduction(instr) || IsInputFusibleScatter(instr));
}

bool IsLoopFusible(const HloInstruction& instr) {
  // Don't fuse get-tuple-element on GPU: We can, but it's slower than not
  // fusing. We never generate kernels for unfused GTEs. Instead, if an
  // unfused GTE is an input to a kernel (including a fusion kernel), we
  // compute the address of the GTE at the top of the kernel. Often we know
  // the address of the GTE result statically, so we can do this without
  // chasing any pointers.
  return instr.IsFusible() &&
         ((instr.IsElementwise() && instr.operand_count() > 0) ||
          instr.opcode() == HloOpcode::kBitcast ||
          instr.opcode() == HloOpcode::kBroadcast ||
          instr.opcode() == HloOpcode::kConcatenate ||
          instr.opcode() == HloOpcode::kDynamicSlice ||
          instr.opcode() == HloOpcode::kDynamicUpdateSlice ||
          (instr.opcode() == HloOpcode::kFusion &&
           instr.fusion_kind() == HloInstruction::FusionKind::kLoop) ||
          instr.opcode() == HloOpcode::kGather ||
          instr.opcode() == HloOpcode::kIota ||
          instr.opcode() == HloOpcode::kPad ||
          (instr.opcode() == HloOpcode::kReduce &&
           !IsReductionFromOrToContiguousDimensions(instr) &&
           !instr.shape().IsTuple()) ||  // TODO(b/129089333): Don't fuse
                                         // variadic reductions.
          instr.opcode() == HloOpcode::kReduceWindow ||
          instr.opcode() == HloOpcode::kReshape ||
          instr.opcode() == HloOpcode::kReverse ||
          instr.opcode() == HloOpcode::kSlice ||
          instr.opcode() == HloOpcode::kConstant ||
          instr.opcode() == HloOpcode::kTranspose);
}

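// Returns true if `producer` may be fused into `consumer` as an ordinary
// (single-output) producer/consumer fusion.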
bool IsProducerConsumerFusible(const HloInstruction& producer,
                               const HloInstruction& consumer) {
  if (!IsLoopFusible(producer)) {
    VLOG(5) << "Producer " << producer.name() << " is not loop-fusible";
    return false;
  }

  if (!IsInputFusible(consumer) && !IsLoopFusible(consumer)) {
    VLOG(5) << "Consumer " << consumer.name()
            << " is not input-fusible and not loop-fusible";
    return false;
  }

  // Skip multi-output fusion. It's not yet supported.
  if (producer.IsMultiOutputFusion()) {
    VLOG(5) << "Producer " << producer.name()
            << " is not fusible as it is a multi-output fusion";
    return false;
  }

  if (CreatesNestedLoop(producer, consumer)) {
    VLOG(5) << "Fusing " << producer.name() << " into " << consumer.name()
            << " creates a nested loop";
    return false;
  }

  // Do not fuse into reduce input fusions if the resulting kernel would
  // suffer from poor data locality (due to unfriendly input layouts).
  if (IsInputFusibleReduction(consumer) &&
      !LayoutsAreReduceInputFusionFriendly(producer, consumer)) {
    VLOG(5) << "Layout of " << producer.name()
            << " is not fusion-friendly for consumer reduction "
            << consumer.name();
    return false;
  }

  // Fuse scalar constants into loop fusion nodes. This reduces the number of
  // parameters and makes matching scalar broadcasts easier.
  //
  // Don't fuse other constants: Unfused constants in GPU land can be
  // represented as an external constant (i.e. not emitted in LLVM IR / PTX),
  // but fused constants are handled by shared CPU/GPU code and always emitted
  // in the IR/PTX. The external constant representation makes for faster
  // compiles and significantly smaller assembly code.
  if (producer.opcode() == HloOpcode::kConstant &&
      (!ShapeUtil::IsEffectiveScalar(producer.shape()) ||
       consumer.opcode() != HloOpcode::kFusion)) {
    VLOG(5) << "Not fusing constant " << producer.name() << " into "
            << consumer.name();
    return false;
  }

  return true;
}

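// Returns true if `producer` and `consumer` may be combined into a single
// multi-output fusion.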
bool IsProducerConsumerMultiOutputFusible(const HloInstruction& producer,
                                          const HloInstruction& consumer) {
  // Skip multi-output fusion. It's not yet supported.
  if (producer.IsMultiOutputFusion()) {
    return false;
  }

  if (!IsLoopFusible(producer) || !IsFusibleAsMultiOutputFusionRoot(consumer)) {
    return false;
  }
  if (CreatesNestedLoop(producer, consumer)) {
    return false;
  }
  if (!ShapesCompatibleForMultiOutputFusion(producer, consumer)) {
    return false;
  }
  if (!LayoutsAreReduceInputFusionFriendly(producer, consumer)) {
    return false;
  }
  return true;
}

// Returns shared memory usage for a given instruction in bytes.
static int64_t SharedMemoryUsage(const HloInstruction& instr) {
  // For now we are only fusing reductions.
  if (IsReductionFromOrToContiguousDimensions(instr)) {
    ReductionDimensions reduction_info =
        GetReductionKindAndContiguousComponents(instr);
    int64_t primitive_size =
        ShapeUtil::ByteSizeOfPrimitiveType(instr.shape().element_type());
    if (reduction_info.is_row_reduction) {
      // __shared__[32] is used for row reduction.
      return 32 * primitive_size;
    } else {
      // __shared__[2][32][33] cache is used for column reduction ("2" comes
      // from potential x-tiling).
      return 2 * 32 * 33 * primitive_size;
    }
  } else if (instr.opcode() == HloOpcode::kFusion) {
    int64_t sum = 0;
    for (const HloInstruction* hlo :
         instr.fused_instructions_computation()->MakeInstructionPostOrder()) {
      sum += SharedMemoryUsage(*hlo);
    }
    return sum;
  }
  // Other fused expressions for now don't need the shared memory budget.
  return 0;
}
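
// Under the estimate above, one f32 column reduction consumes
// 2 * 32 * 33 * 4 = 8448 bytes of shared memory, so the 40000-byte budget
// admits at most four such reductions in a single fusion.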

// Codegen'ing unnested reductions requires a lot of registers, so a MOF
// combining many of those runs a high risk of spilling.
constexpr int64_t kMaxUnnestedReductionOutputsPerFusion = 8;

// Returns the number of unnested reductions in the instruction output.
static int64_t NumUnnestedReductions(const HloInstruction& instr) {
  if (IsReductionFromOrToContiguousDimensions(instr)) {
    return 1;
  }
  if (instr.opcode() == HloOpcode::kFusion) {
    int64_t sum = 0;
    for (const HloInstruction* hlo :
         instr.fused_instructions_computation()->MakeInstructionPostOrder()) {
      sum += NumUnnestedReductions(*hlo);
    }
    return sum;
  }
  return 0;
}

// This function limits the maximum number of operands to a fusion, and the
// amount of shared memory which can be consumed by the fusion.
//
// There's a cap on how many parameters we can pass to a CUDA kernel, but
// exactly where that limit lies is hazy, as it depends on (among other
// things) how much GPU constant memory is in use for other purposes.
//
// Moreover, we don't even know at the point that we're running fusion how
// many arguments the CUDA kernel for a fusion node will have: It depends on
// buffer assignment, where we will decide which of the fusion's operands
// live in XLA's big temp buffer versus in other allocations.
//
// As a heuristic, we simply cap the number of fusion operands plus outputs
// at kMaxOperandsAndOutputsPerFusion. This puts an upper bound on the number
// of parameters to the kernel, working around the correctness problem.
//
// This limit is also often good for performance. In a fusion with many
// operands, each GPU thread likely has to do a lot of work, and so possibly
// uses a lot of registers, thus limiting occupancy.
//
// If this is a producer/consumer fusion with `instr1` as the consumer and
// `instr2` as the producer, set `is_consumer_producer_fusion` to true to
// enable more fusion.
bool FusionWouldBeTooLarge(const HloInstruction& instr1,
                           const HloInstruction& instr2,
                           bool is_consumer_producer_fusion) {
  if (SharedMemoryUsage(instr1) + SharedMemoryUsage(instr2) >
      kSharedMemoryBudgetInBytes) {
    VLOG(5) << "Shared memory usage of fusion of " << instr1.ToString()
            << " and " << instr2.ToString() << " would be over the budget of "
            << kSharedMemoryBudgetInBytes << "B";
    return true;
  }

  if (NumUnnestedReductions(instr1) + NumUnnestedReductions(instr2) >
      kMaxUnnestedReductionOutputsPerFusion) {
    VLOG(5) << "Not fusing over " << kMaxUnnestedReductionOutputsPerFusion
            << " unnested reductions in fusion";
    return true;
  }

  // Compute the number of outputs of the (possibly multi-output) fusion node
  // we're considering creating.
  //
  // This isn't precise; we may be off by one if
  //  - We're creating a multi-output fusion out of two non-MOFs. Creating a
  //    MOF adds a new buffer, namely, the tuple buffer.
  //  - We're merging two MOFs. In this case, we should count the tuple
  //    buffer only once.
  //  - WLOG there's an edge from `a` to `b` and `b` is the only consumer of
  //    `a`. In this case the result of `a` is not part of the output of the
  //    fusion.
  //
  // But because this is a heuristic and our limit
  // kMaxOperandsAndOutputsPerFusion is a large value (so +/- 1 doesn't make
  // a big difference), we ignore this small inaccuracy in favor of
  // simplicity.
  int64_t num_output_buffers = ShapeUtil::SubshapeCount(instr1.shape()) +
                               ShapeUtil::SubshapeCount(instr2.shape());

  // The new fusion will have no more operands and outputs than
  //   producer_operands + consumer_operands - 1 + num_output_buffers
  // (minus one because we may be fusing a producer->consumer edge between
  // `a` and `b`).
  //
  // This fact may be enough to let us avoid having to compute the true total
  // number of operands, which can be expensive.
  if (instr1.operand_count() + instr2.operand_count() - 1 +
          num_output_buffers <=
      kMaxOperandsAndOutputsPerFusion) {
    return false;
  } else {
    VLOG(5) << "Operand count of "
            << "(" << instr1.ToString() << ") = " << instr1.operand_count()
            << " and (" << instr2.ToString()
            << ") = " << instr2.operand_count()
            << " and num_output_buffers = " << num_output_buffers
            << " is bigger than the bound of "
            << kMaxOperandsAndOutputsPerFusion;
  }

  // Compute the precise number of operands to the new fusion.
  absl::flat_hash_set<const HloInstruction*> operands(
      instr1.operands().begin(), instr1.operands().end());
  operands.insert(instr2.operands().begin(), instr2.operands().end());
  // If there's an edge between `instr1` and `instr2`, don't count it: We're
  // fusing that producer -> consumer relationship.
  operands.erase(&instr1);
  operands.erase(&instr2);

  // If the fusion has no more inputs than before, it won't get bigger after
  // fusion, so accept it. Since this is a consumer/producer fusion, the
  // number of consumer outputs does not change, so there is no need to check
  // it.
  if (is_consumer_producer_fusion &&
      operands.size() <= instr1.operands().size()) {
    return false;
  }

  // Does the new fusion have more operands and outputs than the max?
  return operands.size() + num_output_buffers >
         kMaxOperandsAndOutputsPerFusion;
}

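// Returns true if fusing `producer` into `consumer` would nest one
// element-rereading loop inside another: `producer` emits such a loop and
// its result feeds, directly or transitively, an instruction in `consumer`
// that also emits one.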
bool CreatesNestedLoop(const HloInstruction& producer,
                       const HloInstruction& consumer) {
  // If the producer does not have an instruction that codegens a loop then
  // there is nothing to do.
  auto producer_has_loop_codegen = [](const HloInstruction& instr) {
    if (instr.opcode() != HloOpcode::kFusion) {
      return IfFusedReadsElementsMultipleTimes(instr);
    }
    for (const auto& fused : instr.fused_instructions()) {
      if (IfFusedReadsElementsMultipleTimes(*fused)) {
        return true;
      }
    }
    return false;
  };
  if (!producer_has_loop_codegen(producer)) {
    return false;
  }

  // If the consumer is a non-fusion instruction then we have to check if it
  // generates a loop.
  if (consumer.opcode() != HloOpcode::kFusion) {
    return IfFusedReadsElementsMultipleTimes(consumer);
  }

  // If the consumer is a fusion then we have to check if the output of the
  // producer is used directly or indirectly as an input to an HLO
  // instruction that generates a loop, i.e. there is a path in the graph
  // from an operand corresponding to the producer to an HLO instruction
  // generating a loop in the consumer.
  for (const HloInstruction* operand : consumer.operands()) {
    if (operand != &producer) {
      continue;
    }

    const HloInstruction* root =
        consumer.fused_instructions_computation()->parameter_instruction(
            consumer.operand_index(operand));

    std::stack<const HloInstruction*> dfs;
    dfs.push(root);
    absl::flat_hash_set<const HloInstruction*> visited;
    while (!dfs.empty()) {
      const HloInstruction* cur = dfs.top();
      dfs.pop();

      if (visited.contains(cur)) {
        continue;
      }
      visited.insert(cur);

      if (IfFusedReadsElementsMultipleTimes(*cur)) {
        return true;
      }
      for (const auto& user : cur->users()) {
        if (visited.contains(user)) {
          continue;
        }
        dfs.push(user);
      }
    }
  }
  return false;
}

bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr) {
  // We can fuse reduces and loop fusions. Elementwise instructions can be
  // fused with any other instruction.
  // Note that scatter cannot be the root of a multi-output fusion because
  // its emitter doesn't support it.

  return instr.IsFusible() &&
         (IsInputFusibleReduction(instr) ||
          instr.IsLoopFusion() ||  // TODO(b/130013493): Use IsLoopFusible
                                   // here.
          instr.IsElementwise());
}

HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& /*producer*/,
                                            const HloInstruction& consumer) {
  return IsInputFusible(consumer) ? HloInstruction::FusionKind::kInput
                                  : HloInstruction::FusionKind::kLoop;
}

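// Returns true if every user of `instr` is either `consumer`, the
// computation root, or a get-tuple-element that itself satisfies this
// property.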
bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr,
                                  const HloInstruction& consumer) {
  return absl::c_all_of(instr.users(), [&](const HloInstruction* user) {
    if (user->opcode() == HloOpcode::kGetTupleElement) {
      // Skip GTE.
      return IsConsumerTheOnlyNonRootUser(*user, consumer);
    }
    if (user == &consumer) {
      // `user` is `consumer`.
      return true;
    }
    if (user == user->parent()->root_instruction()) {
      // Consumed by ROOT.
      return true;
    }
    return false;
  });
}

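// Returns the number of instructions `instr` stands for: the fused
// instruction count if it is a fusion, otherwise 1.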
size_t GetInstrCountOfFusible(const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kFusion) {
    return 1;
  }
  return instr.fused_instruction_count();
}

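// Returns the instructions that define the outputs of `instr`: the
// instruction itself if unfused, the fusion root if single-output, or the
// operands of the root tuple for a multi-output fusion.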
absl::InlinedVector<const HloInstruction*, 2> GetOutputsOfFusible(
    const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kFusion) {
    return {&instr};
  }

  HloInstruction* root = instr.fused_expression_root();
  if (root->opcode() != HloOpcode::kTuple) {
    return {root};
  }
  auto v = root->operands();
  return absl::InlinedVector<const HloInstruction*, 2>(v.begin(), v.end());
}

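// Returns the number of outputs of `instr`: 1 unless it is a multi-output
// fusion, in which case it is the element count of the root tuple.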
size_t GetOutputSizeOfFusible(const HloInstruction& instr) {
  if (!instr.IsMultiOutputFusion()) {
    return 1;
  }
  const HloInstruction* root = instr.fused_expression_root();
  return ShapeUtil::TupleElementCount(root->shape());
}

}  // namespace gpu
}  // namespace xla