/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"

#include <algorithm>
#include <iterator>
#include <stack>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/instruction_fusion.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {
namespace gpu {
namespace {

// The amount of shared memory a CUDA kernel can use.
//
// Stay on the conservative side, this is smaller than full 64kB, but allows
// some extra space for cache.
int64_t kSharedMemoryBudgetInBytes = 40000;

bool IfFusedReadsElementsMultipleTimes(const HloInstruction& instr) {
  CHECK_NE(instr.opcode(), HloOpcode::kFusion) << "`instr` has to be unfused.";
  if (instr.opcode() == HloOpcode::kReduce &&
      !IsReductionFromOrToContiguousDimensions(instr)) {
    return true;
  }
  // Avoid fusing reduce-window when stride is less than window size to
  // minimize the number of reads of the same elements.
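  // For example, a window of size 3 with stride 1 revisits each input element
  // up to three times, so a fused producer would be recomputed for each of
  // those reads.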
  if (instr.opcode() == HloOpcode::kReduceWindow) {
    for (const auto& dim : instr.window().dimensions()) {
      if (dim.size() > dim.stride()) {
        return true;
      }
    }
  }
  return false;
}

}  // namespace

bool IsPhysicallyTransposing(const HloInstruction& instr) {
  if (instr.opcode() == HloOpcode::kFusion) {
    for (const HloInstruction* fused_instr : instr.fused_instructions()) {
      if (IsPhysicallyTransposing(*fused_instr)) {
        return true;
      }
    }
  }

  // A fusion iterates over its output in physically-contiguous order. This
  // applies "upwards" to operands. Only an operator that changes an operand's
  // physical layout can create a "bad" memory access pattern.
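  // For example, transposing f32[1,1024] into f32[1024,1] is a bitcast under
  // the default layouts and does not count, while transposing both dimensions
  // of f32[512,1024] does.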
  return instr.opcode() == HloOpcode::kCopy ||
         (instr.opcode() == HloOpcode::kTranspose &&
          !ShapeUtil::TransposeIsBitcast(instr.operand(0)->shape(),
                                         instr.shape(), instr.dimensions()));
}

bool IsReduceInputFusion(const HloInstruction& instr) {
  if (instr.IsMultiOutputFusion()) {
    for (const HloInstruction* operand :
         instr.fused_expression_root()->operands()) {
      if (IsReductionFromOrToContiguousDimensions(*operand)) {
        CHECK(instr.IsInputFusion())
            << " Multi-output fusion rooted at reduction-to-vector ops must be "
               "of kind kInput: "
            << instr.ToString();
        return true;
      }
    }
  } else if (instr.opcode() == HloOpcode::kFusion &&
             IsReductionFromOrToContiguousDimensions(
                 *instr.fused_expression_root())) {
    CHECK(instr.IsInputFusion())
        << " Fusion rooted at reduction-to-vector op must be of kind kInput: "
        << instr.ToString();
    return true;
  }
  return false;
}

bool IsInputFusibleReduction(const HloInstruction& instr) {
  return IsReduceInputFusion(instr) ||
         IsReductionFromOrToContiguousDimensions(instr);
}

const HloInstruction* GetRealHeroForMultiOutputFusion(
    const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kFusion) {
    return &instr;
  }
  auto fused_expression_root = instr.fused_expression_root();
  if (!instr.IsMultiOutputFusion()) {
    return fused_expression_root;
  }
  // If possible, we want to pick a reduction-from-or-to-contiguous-dims
  // operand of the fusion root, because it has the most constraints.
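  // For example, if the fusion root is tuple(add, reduce), we return the
  // reduce, since it constrains the fusion the most.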
  for (const auto* inst : fused_expression_root->operands()) {
    if (IsReductionFromOrToContiguousDimensions(*inst)) {
      return inst;
    }
  }
  return fused_expression_root->operands()[0];
}

bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
                                          const HloInstruction& instr2) {
  // Multi-output fusion kernels share a common parallel loop. The loop
  // dimensions are determined by instruction shapes.
  auto get_loop_shape = [&](const HloInstruction* element_instr) {
    // Special-case reduction-to-vector ops: The loop dimensions are determined
    // by the shape of the first operand.
    if (IsReductionFromOrToContiguousDimensions(*element_instr)) {
      return element_instr->operand(0)->shape();
    }
    return element_instr->shape();
  };

  // All shapes of the root tuple of multi-output fusions should agree, i.e.
  // all root ops should have equal output shapes. The exception is
  // reduction-to-vector ops, where the input shape of the reduction (first
  // operand shape) and the reduction dimensions need to match.
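  // For example, a reduce from f32[1024,128] to f32[1024] is compatible with
  // an elementwise op of shape f32[1024,128] (assuming matching layouts): both
  // share the [1024,128] loop shape.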
  auto* instr_1 = GetRealHeroForMultiOutputFusion(instr1);
  auto* instr_2 = GetRealHeroForMultiOutputFusion(instr2);
  if (IsReductionFromOrToContiguousDimensions(*instr_1) &&
      IsReductionFromOrToContiguousDimensions(*instr_2) &&
      !AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) {
    return false;
  }
  // The elementwise output shapes must be the same (including layout).
  return ShapeUtil::EqualIgnoringElementType(get_loop_shape(instr_1),
                                             get_loop_shape(instr_2));
}

bool IsInputFusibleScatter(const HloInstruction& instr) {
  if (instr.opcode() == HloOpcode::kScatter ||
      (instr.opcode() == HloOpcode::kFusion &&
       instr.fusion_kind() == HloInstruction::FusionKind::kInput &&
       instr.fused_expression_root()->opcode() == HloOpcode::kScatter)) {
    return true;
  }
  return false;
}

bool IsInputFusible(const HloInstruction& instr) {
  // Input fusion only handles non-elementwise reductions and scatter
  // operations.
  return instr.IsFusible() &&
         (IsInputFusibleReduction(instr) || IsInputFusibleScatter(instr));
}

bool IsLoopFusible(const HloInstruction& instr) {
  // Don't fuse get-tuple-element on GPU: We can, but it's slower than not
  // fusing. We never generate kernels for unfused GTEs. Instead, if an
  // unfused GTE is an input to a kernel (including a fusion kernel), we
  // compute the address of the GTE at the top of the kernel. Often we know
  // the address of the GTE result statically, so we can do this without
  // chasing any pointers.
  return instr.IsFusible() &&
         ((instr.IsElementwise() && instr.operand_count() > 0) ||
          instr.opcode() == HloOpcode::kBitcast ||
          instr.opcode() == HloOpcode::kBroadcast ||
          instr.opcode() == HloOpcode::kConcatenate ||
          instr.opcode() == HloOpcode::kDynamicSlice ||
          instr.opcode() == HloOpcode::kDynamicUpdateSlice ||
          (instr.opcode() == HloOpcode::kFusion &&
           instr.fusion_kind() == HloInstruction::FusionKind::kLoop) ||
          instr.opcode() == HloOpcode::kGather ||
          instr.opcode() == HloOpcode::kIota ||
          instr.opcode() == HloOpcode::kPad ||
          (instr.opcode() == HloOpcode::kReduce &&
           !IsReductionFromOrToContiguousDimensions(instr) &&
           !instr.shape().IsTuple()) ||  // TODO(b/129089333): Don't fuse
                                         // variadic reductions.
          instr.opcode() == HloOpcode::kReduceWindow ||
          instr.opcode() == HloOpcode::kReshape ||
          instr.opcode() == HloOpcode::kReverse ||
          instr.opcode() == HloOpcode::kSlice ||
          instr.opcode() == HloOpcode::kConstant ||
          instr.opcode() == HloOpcode::kTranspose);
}

FusionDecision IsProducerConsumerFusible(const HloInstruction& producer,
                                         const HloInstruction& consumer) {
  if (!IsLoopFusible(producer)) {
    return "the producer is not loop-fusible";
  }

  if (!IsInputFusible(consumer) && !IsLoopFusible(consumer)) {
    return "the consumer is not input-fusible and not loop-fusible";
  }

  // Skip multiple output fusion. It's not yet supported.
  if (producer.IsMultiOutputFusion()) {
    return "the producer is not fusible as it is a multi-output fusion";
  }

  if (CreatesNestedLoop(producer, consumer)) {
    return "the fusion would create a nested loop";
  }

  // Do not fuse into fusions if the resulting kernel would suffer from
  // uncoalesced reads due to a transposed memory access pattern.
  if (IsInputFusibleReduction(consumer) && IsPhysicallyTransposing(producer)) {
    return "fusing the producer would break read coalescing";
  }

  // Fuse scalar constants into loop fusion nodes. This reduces the number of
  // parameters and makes matching scalar broadcasts easier.
  //
  // Don't fuse other constants: Unfused constants in GPU land can be
  // represented as an external constant (i.e. not emitted in LLVM IR / PTX),
  // but fused constants are handled by shared CPU/GPU code and always emitted
  // in the IR/PTX. The external constant representation makes for faster
  // compiles and significantly smaller assembly code.
  if (producer.opcode() == HloOpcode::kConstant &&
      (!ShapeUtil::IsEffectiveScalar(producer.shape()) ||
       consumer.opcode() != HloOpcode::kFusion)) {
    return "not fusing constant";
  }

  // Make sure the new fusion obeys the in-place semantics.
  return InstructionFusion::ShouldFuseInPlaceOp(&producer, &consumer);
}

bool IsProducerConsumerMultiOutputFusible(const HloInstruction& producer,
                                          const HloInstruction& consumer) {
  // Skip multiple output fusion. It's not yet supported.
  if (producer.IsMultiOutputFusion()) {
    return false;
  }

  // Allowing multi-output fusions that contain in-place operations makes code
  // generation more difficult. For the generated loop to iterate over all
  // outputs in parallel, it must find an iteration order that guarantees that
  // no loop iteration writes an element of any in-place operand that is read
  // or written by any other iteration. For example:
  //
  //   %fused_computation {
  //     %param_0 = s32[4,4]{1,0} parameter(0)
  //     ...
  //     %updated = s32[4,4]{1,0} dynamic-update-slice(
  //         %param_0, %add, %constant_1, %constant_0)
  //     %transpose = s32[4,4]{0,1} transpose(%updated), dimensions={1,0}
  //     ROOT %tuple.5 = tuple(%transpose, %updated)
  //   }
  //
  // Iterating 'transpose' and 'updated' in parallel by array index is
  // not valid, because an iteration that produces some element of 'transpose'
  // will read from an element of 'param_0' that has been overwritten by some
  // other iteration (writing to 'updated').
  //
  // To avoid these problems, we simply ban fusion altogether when the producer
  // is in-place. (We can relax this restriction by establishing an explicit
  // contract that describes what multi-output fusion scenarios are supported
  // by codegen and then changing this check to allow exactly those fusions).
  if (!HloDataflowAnalysis::GetInPlaceInputOutputPairs(&producer).empty()) {
    return false;
  }
  if (!IsLoopFusible(producer) || !IsFusibleAsMultiOutputFusionRoot(consumer)) {
    return false;
  }
  if (CreatesNestedLoop(producer, consumer)) {
    return false;
  }
  if (!ShapesCompatibleForMultiOutputFusion(producer, consumer)) {
    return false;
  }
  if (IsPhysicallyTransposing(producer)) {
    return false;
  }
  return true;
}

// Returns shared memory usage for a given instruction in bytes.
static int64_t SharedMemoryUsageNoCache(const HloInstruction& instr) {
  // For now we are only fusing reductions.
  if (instr.opcode() == HloOpcode::kReduce &&
      IsReductionFromOrToContiguousDimensions(instr)) {
    ReductionDimensions reduction_info =
        GetReductionKindAndContiguousComponents(instr);
    int64_t primitive_size = ShapeUtil::ByteSizeOfPrimitiveType(
        instr.operand(0)->shape().element_type());
    int num_variadic =
        instr.shape().IsTuple() ? instr.shape().tuple_shapes_size() : 1;
    if (reduction_info.is_row_reduction) {
      // __shared__[32] is used for row reduction.
      return 32 * primitive_size * num_variadic;
    } else {
      // __shared__[2][32][33] cache is used for column reduction ("2" comes
      // from potential x-tiling).
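      // For an f32 reduction, for example, this is 2 * 32 * 33 * 4 = 8448
      // bytes per output.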
      return 2 * 32 * 33 * primitive_size * num_variadic;
    }
  } else if (instr.opcode() == HloOpcode::kFusion) {
    int64_t sum = 0;
    for (const HloInstruction* hlo :
         instr.fused_instructions_computation()->instructions()) {
      sum += SharedMemoryUsageNoCache(*hlo);
    }
    return sum;
  }
  // Other fused expressions for now don't need the shared memory budget.
  return 0;
}

static int64_t SharedMemoryUsage(const HloInstruction& instr,
                                 FusionInfoCache* cache = nullptr) {
  if (!cache) {
    return SharedMemoryUsageNoCache(instr);
  }

  // nb: Users are only expected to call cache.Invalidate() on top-level
  // instructions, not instructions inside fusion nodes. Therefore we can only
  // cache top-level instructions; it would not be valid to pass the cache to
  // SharedMemoryUsageNoCache and use the cache *within* the fusion.
  auto it_and_inserted = cache->shared_memory_usage.emplace(&instr, -1);
  auto it = it_and_inserted.first;
  auto inserted = it_and_inserted.second;

  if (inserted) {
    it->second = SharedMemoryUsageNoCache(instr);
  }
  return it->second;
}

// Codegen'ing unnested reductions requires a lot of registers, so a
// multi-output fusion (MOF) that combines many of those runs a high risk of
// spilling.
constexpr int64_t kMaxUnnestedReductionOutputsPerFusion = 8;

// Returns the number of unnested reductions in the instruction output.
static int64_t NumUnnestedReductionsNoCache(const HloInstruction& instr) {
  if (instr.opcode() == HloOpcode::kReduce &&
      IsReductionFromOrToContiguousDimensions(instr)) {
    return 1;
  }
  if (instr.opcode() == HloOpcode::kFusion) {
    int64_t sum = 0;
    for (const HloInstruction* hlo :
         instr.fused_instructions_computation()->instructions()) {
      sum += NumUnnestedReductionsNoCache(*hlo);
    }
    return sum;
  }
  return 0;
}

static int64_t NumUnnestedReductions(const HloInstruction& instr,
                                     FusionInfoCache* cache) {
  if (!cache) {
    return NumUnnestedReductionsNoCache(instr);
  }

  // nb: Users are only expected to call cache.Invalidate() on top-level
  // instructions, not instructions inside fusion nodes. Therefore we can only
  // cache top-level instructions; it would not be valid to pass the cache to
  // NumUnnestedReductionsNoCache and use the cache *within* the fusion.
  auto it_and_inserted = cache->num_unnested_reductions.emplace(&instr, -1);
  auto it = it_and_inserted.first;
  auto inserted = it_and_inserted.second;

  if (inserted) {
    it->second = NumUnnestedReductionsNoCache(instr);
  }
  return it->second;
}

// This function limits the maximum number of operands to a fusion, and the
// amount of shared memory which can be consumed by the fusion.
//
// There's a cap on how many parameters we can pass to a CUDA kernel, but
// exactly what that limit is remains hazy, as it depends on (among other
// things) how much GPU constant memory is in use for other purposes.
//
// Moreover, we don't even know at the point that we're running fusion how many
// arguments the CUDA kernel for a fusion node will have: It depends on buffer
// assignment, where we will decide which of the fusion's operands live in
// XLA's big temp buffer versus in other allocations.
//
// As a heuristic, we simply cap the number of fusion operands plus outputs at
// MaxOperandsAndOutputsPerFusion(). This puts an upper bound on the number of
// parameters to the kernel, working around the correctness problem.
//
// This limit is also often good for performance. In a fusion with many
// operands, each GPU thread likely has to do a lot of work, and so possibly
// uses a lot of registers, thus limiting occupancy.
//
// If the fusion is a producer/consumer fusion and instr1 is the
// consumer and instr2 is the producer, set is_consumer_producer_fusion
// to true to enable more fusion.
FusionDecision FusionFitsInBudget(const HloInstruction& instr1,
                                  const HloInstruction& instr2,
                                  bool is_consumer_producer_fusion,
                                  FusionInfoCache* cache /*=nullptr*/) {
  if (SharedMemoryUsage(instr1, cache) + SharedMemoryUsage(instr2, cache) >
      kSharedMemoryBudgetInBytes) {
    return FusionDecision{}
           << "shared memory usage would be over the budget of "
           << kSharedMemoryBudgetInBytes << "B";
  }

  if (NumUnnestedReductions(instr1, cache) +
          NumUnnestedReductions(instr2, cache) >
      kMaxUnnestedReductionOutputsPerFusion) {
    return FusionDecision{} << "over " << kMaxUnnestedReductionOutputsPerFusion
                            << " unnested reductions in fusion";
  }

  // Compute the number of outputs of the (possibly multi-output) fusion node
  // we're considering creating.
  //
  // This isn't precise; we may be off by one if
  //  - We're creating a multi-output fusion out of two non-MOFs. Creating a
  //    MOF adds a new buffer, namely, the tuple buffer.
  //  - We're merging two MOFs. In this case, we should count the tuple buffer
  //    only once.
  //  - WLOG there's an edge from `a` to `b` and `b` is the only consumer of
  //    `a`. In this case the result of `a` is not part of the output of the
  //    fusion.
  //
  // But because this is a heuristic and our limit
  // MaxOperandsAndOutputsPerFusion() is a large value (so +/- 1 doesn't make a
  // big difference), we ignore this small inaccuracy in favor of simplicity.
  int64_t num_output_buffers = ShapeUtil::SubshapeCount(instr1.shape()) +
                               ShapeUtil::SubshapeCount(instr2.shape());

  // The new fusion will have no more operands and outputs than
  //   producer_operands + consumer_operands - 1 + num_output_buffers
  // (minus one because we may be fusing a producer->consumer edge between `a`
  // and `b`).
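  // For example, a producer with 3 operands fused into a consumer with 2
  // operands (one of which is the producer itself) has at most
  // 3 + 2 - 1 = 4 operands.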
  //
  // This fact may be enough to let us avoid having to compute the true total
  // number of operands, which can be expensive.
  if (instr1.operand_count() + instr2.operand_count() - 1 +
          num_output_buffers <=
      MaxOperandsAndOutputsPerFusion()) {
    return {};
  } else {
    VLOG(5) << "Operand count of "
            << "(" << instr1.ToString() << " ) = " << instr1.operand_count()
            << " and ( " << instr2.ToString()
            << " ) = " << instr2.operand_count()
            << " and num_output_buffers = " << num_output_buffers
            << " is bigger than the bound of "
            << MaxOperandsAndOutputsPerFusion();
  }

  // Compute the precise number of operands to the new fusion.
  absl::flat_hash_set<const HloInstruction*> operands(instr1.operands().begin(),
                                                      instr1.operands().end());
  operands.insert(instr2.operands().begin(), instr2.operands().end());
  // If there's an edge between `a` and `b`, don't count it: We're fusing that
  // producer -> consumer relationship.
  operands.erase(&instr1);
  operands.erase(&instr2);

  // If we generate the same number of inputs and outputs as before, the fusion
  // won't grow, so accept it. As this is a consumer/producer fusion, it does
  // not change the consumer's number of outputs, so there is no need to check
  // them.
  if (is_consumer_producer_fusion &&
      operands.size() <= instr1.operands().size()) {
    return {};
  }

  // Does the new fusion have more operands and outputs than the max?
  if (operands.size() + num_output_buffers > MaxOperandsAndOutputsPerFusion()) {
    return "Number of operands and output buffers is larger than allowed "
           "budget per fusion";
  }
  return {};
}

bool CreatesNestedLoop(const HloInstruction& producer,
                       const HloInstruction& consumer) {
  // If producer does not have an instruction that codegens a loop then there
  // is nothing to do.
  auto producer_has_loop_codegen = [&](const HloInstruction& instr) {
    if (producer.opcode() != HloOpcode::kFusion) {
      return IfFusedReadsElementsMultipleTimes(producer);
    }
    for (const auto& instr : producer.fused_instructions()) {
      if (IfFusedReadsElementsMultipleTimes(*instr)) {
        return true;
      }
    }
    return false;
  };
  if (!producer_has_loop_codegen(producer)) {
    return false;
  }

  // If consumer is a non-fusion instruction then we have to check if it
  // generates a loop.
  if (consumer.opcode() != HloOpcode::kFusion) {
    return IfFusedReadsElementsMultipleTimes(consumer);
  }

  // If consumer is a fusion then we have to check if the output of producer is
  // used directly or indirectly as an input to an HLO instruction that
  // generates a loop, i.e. there is a path in the graph from an operand
  // corresponding to the producer to an HLO instruction generating a loop in
  // the consumer.
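  // For example, if the producer is a reduce-window and the corresponding
  // parameter inside the consumer fusion feeds another reduce-window, fusing
  // would re-evaluate the producer's window loop inside the consumer's loop.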
  for (const HloInstruction* operand : consumer.operands()) {
    if (operand != &producer) {
      continue;
    }

    const HloInstruction* root =
        consumer.fused_instructions_computation()->parameter_instruction(
            consumer.operand_index(operand));

    std::stack<const HloInstruction*> dfs;
    dfs.push(root);
    absl::flat_hash_set<const HloInstruction*> visited;
    while (!dfs.empty()) {
      const HloInstruction* cur = dfs.top();
      dfs.pop();

      if (visited.contains(cur)) {
        continue;
      }
      visited.insert(cur);

      if (IfFusedReadsElementsMultipleTimes(*cur)) {
        return true;
      }
      for (const auto& user : cur->users()) {
        if (visited.contains(user)) {
          continue;
        }
        dfs.push(user);
      }
    }
  }
  return false;
}

bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr) {
  // We can fuse reduces and loop fusions. Elementwise instructions can be
  // fused with any other instruction.
  // Note that scatter cannot be the root of a multi-output fusion because
  // its emitter doesn't support it.

  return instr.IsFusible() &&
         (IsInputFusibleReduction(instr) ||
          instr.IsLoopFusion() ||  // TODO(b/130013493): Use IsLoopFusible here.
          instr.IsElementwise());
}

HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& /*producer*/,
                                            const HloInstruction& consumer) {
  return IsInputFusible(consumer) ? HloInstruction::FusionKind::kInput
                                  : HloInstruction::FusionKind::kLoop;
}

bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr,
                                  const HloInstruction& consumer) {
  return absl::c_all_of(instr.users(), [&](const HloInstruction* user) {
    if (user->opcode() == HloOpcode::kGetTupleElement) {
      // Skip GTE.
      return IsConsumerTheOnlyNonRootUser(*user, consumer);
    }
    if (user == &consumer) {
      // `user` is `consumer`.
      return true;
    }
    if (user == user->parent()->root_instruction()) {
      // Consumed by ROOT.
      return true;
    }
    return false;
  });
}

size_t GetInstrCountOfFusible(const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kFusion) {
    return 1;
  } else {
    return instr.fused_instruction_count();
  }
}

absl::InlinedVector<const HloInstruction*, 2> GetOutputsOfFusible(
    const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kFusion) {
    return {&instr};
  }

  HloInstruction* root = instr.fused_expression_root();
  if (root->opcode() != HloOpcode::kTuple) {
    return {root};
  } else {
    auto v = root->operands();
    return absl::InlinedVector<const HloInstruction*, 2>(v.begin(), v.end());
  }
}

size_t GetOutputSizeOfFusible(const HloInstruction& instr) {
  if (!instr.IsMultiOutputFusion()) {
    return 1;
  }
  const HloInstruction* root = instr.fused_expression_root();
  return ShapeUtil::TupleElementCount(root->shape());
}

}  // namespace gpu
}  // namespace xla