/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/instruction_fusion.h"

#include <algorithm>
#include <list>
#include <memory>
#include <numeric>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/fusion_queue.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_reachability.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace {
// These nodes can always be duplicated into consumers, even if
// InstructionFusion::may_duplicate_ is false.
//
// In general these should be nodes that get *cheaper* the more they're
// duplicated (and fused into consumers).
//
// TODO(jlebar): Duplicating instructions when we have a variable called "may
// duplicate" that's equal to false is not pretty.
bool IsAlwaysDuplicable(const HloInstruction& instruction) {
  // We are always willing to duplicate a widening type-conversion instruction
  // if it means we can fuse the convert into a consumer. This allows the
  // consumer to read less memory, which is almost always a performance win.
  return instruction.opcode() == HloOpcode::kConvert &&
         ShapeUtil::ByteSizeOf(instruction.operand(0)->shape()) <
             ShapeUtil::ByteSizeOf(instruction.shape());
}
}  // namespace

/*static*/ bool InstructionFusion::IsExpensive(
    const HloInstruction& instruction) {
  switch (instruction.opcode()) {
    // Cheap instructions.
    case HloOpcode::kAdd:
    case HloOpcode::kAnd:
    case HloOpcode::kBitcast:
    case HloOpcode::kBitcastConvert:
    case HloOpcode::kBroadcast:
    case HloOpcode::kCeil:
    case HloOpcode::kClamp:
    case HloOpcode::kClz:
    case HloOpcode::kCompare:
    case HloOpcode::kComplex:
    case HloOpcode::kConcatenate:
    case HloOpcode::kConstant:
    case HloOpcode::kConvert:
    case HloOpcode::kCopy:
    case HloOpcode::kDynamicSlice:
    case HloOpcode::kDynamicUpdateSlice:
    case HloOpcode::kFloor:
    case HloOpcode::kGetTupleElement:
    case HloOpcode::kImag:
    case HloOpcode::kInfeed:
    case HloOpcode::kIota:
    case HloOpcode::kIsFinite:
    case HloOpcode::kMaximum:
    case HloOpcode::kMinimum:
    case HloOpcode::kMultiply:
    case HloOpcode::kNegate:
    case HloOpcode::kNot:
    case HloOpcode::kOr:
    case HloOpcode::kXor:
    case HloOpcode::kOutfeed:
    case HloOpcode::kPad:
    case HloOpcode::kReal:
    case HloOpcode::kReducePrecision:
    case HloOpcode::kReplicaId:
    case HloOpcode::kReshape:
    case HloOpcode::kReverse:
    case HloOpcode::kRoundNearestAfz:
    case HloOpcode::kSelect:
    case HloOpcode::kShiftLeft:
    case HloOpcode::kShiftRightArithmetic:
    case HloOpcode::kShiftRightLogical:
    case HloOpcode::kSlice:
    case HloOpcode::kSubtract:
    case HloOpcode::kTranspose:
    case HloOpcode::kTuple:
    case HloOpcode::kTupleSelect:
      return false;

    // Cheap instructions for reals, but expensive for complex.
    case HloOpcode::kAbs:
    case HloOpcode::kCos:
    case HloOpcode::kSign:
    case HloOpcode::kSin:
      return ShapeUtil::ElementIsComplex(instruction.shape());

    // Expensive instructions or unusual instructions for which fusion is
    // nonsensical.
    case HloOpcode::kAddDependency:
    case HloOpcode::kAfterAll:
    case HloOpcode::kAtan2:
    case HloOpcode::kBatchNormGrad:
    case HloOpcode::kBatchNormInference:
    case HloOpcode::kBatchNormTraining:
    case HloOpcode::kCall:
    case HloOpcode::kCholesky:
    case HloOpcode::kConditional:
    case HloOpcode::kConvolution:
    case HloOpcode::kAllReduce:
    case HloOpcode::kAllToAll:
    case HloOpcode::kCollectivePermute:
    case HloOpcode::kCustomCall:
    case HloOpcode::kDivide:
    case HloOpcode::kDomain:
    case HloOpcode::kDot:
    case HloOpcode::kExp:
    case HloOpcode::kExpm1:
    case HloOpcode::kFft:
    case HloOpcode::kFusion:
    case HloOpcode::kGather:
    case HloOpcode::kLog:
    case HloOpcode::kLog1p:
    case HloOpcode::kMap:
    case HloOpcode::kParameter:
    case HloOpcode::kPower:
    case HloOpcode::kRecv:
    case HloOpcode::kRecvDone:
    case HloOpcode::kReduce:
    case HloOpcode::kReduceWindow:
    case HloOpcode::kRemainder:
    case HloOpcode::kRng:
    case HloOpcode::kRsqrt:
    case HloOpcode::kScatter:
    case HloOpcode::kSelectAndScatter:
    case HloOpcode::kSend:
    case HloOpcode::kSendDone:
    case HloOpcode::kSort:
    case HloOpcode::kSqrt:
    case HloOpcode::kTanh:
    case HloOpcode::kTrace:
    case HloOpcode::kTriangularSolve:
    case HloOpcode::kWhile:
    case HloOpcode::kGetDimensionSize:
      return true;
  }

  return false;
}

// An "effectively at most unary" operation is one that has at most one "large"
// input with the others being negligible in terms of memory usage.
// We use "has a smaller true rank than the output" as a heuristic
// for "negligible" memory usage.
bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) {
  int64 output_rank = 0;
  ShapeUtil::ForEachSubshape(
      hlo->shape(),
      [&output_rank](const Shape& subshape, const ShapeIndex& shape_index) {
        if (subshape.IsArray()) {
          output_rank = std::max(output_rank, ShapeUtil::TrueRank(subshape));
        }
      });
  return absl::c_count_if(
             hlo->operands(), [output_rank](HloInstruction* operand) {
               if (operand->opcode() == HloOpcode::kBroadcast ||
                   operand->opcode() == HloOpcode::kIota) {
                 return false;
               }
               if (operand->opcode() == HloOpcode::kConstant &&
                   ShapeUtil::IsEffectiveScalar(operand->shape())) {
                 return false;
               }
               return ShapeUtil::TrueRank(operand->shape()) >= output_rank;
             }) <= 1;
}

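// Returns true if `producer` can be fused into `consumer` along every path
// between them, i.e. every intermediate instruction on every such path can
// itself be fused into its user and is not in `do_not_fuse`.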
bool InstructionFusion::CanFuseOnAllPaths(
    HloInstruction* producer, HloInstruction* consumer,
    const HloInstructionSet& do_not_fuse,
    absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>*
        result_cache) {
  if (consumer == producer) {
    return true;
  }
  if (!consumer->IsFusible()) {
    return false;
  }
  auto cache_it = result_cache->find(std::make_pair(producer, consumer));
  if (cache_it != result_cache->end()) {
    return cache_it->second;
  }
  bool result = true;
  for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) {
    auto* consumer_operand = consumer->mutable_operand(i);
    // If the operand is not on a path to the producer, it doesn't matter
    // whether it's fusible.
    if (!reachability_->IsReachable(producer, consumer_operand)) {
      continue;
    }
    if (do_not_fuse.count(consumer_operand) > 0 || !ShouldFuse(consumer, i)) {
      result = false;
      break;
    }
    // The producer is reachable from consumer_operand which means we need
    // to be able to fuse consumer_operand into consumer in order for
    // producer to be fusible into consumer on all paths.
    // Perform the recursive step: make sure producer can be fused into
    // consumer_operand on all paths.
    if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_fuse,
                           result_cache)) {
      result = false;
      break;
    }
  }
  result_cache->emplace(std::make_pair(producer, consumer), result);
  return result;
}

InstructionFusion::HloInstructionSet
InstructionFusion::ComputeGloballyUnfusible(
    absl::Span<HloInstruction* const> post_order) {
  // Forbid fusion of producers that:
  // a) Need to be duplicated, unless they can be fused into all consumers
  //    via all paths.
  // b) Are more than unary, that is, fusing them would likely lead to an
  //    increase in memory bandwidth use.
  //
  // Note that if we allow fusion by these global rules, we may still forbid
  // fusing operations that require duplication later depending on
  // is_expensive_().
  HloInstructionSet do_not_duplicate;
  absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>
      can_fuse_on_all_paths_result_cache;
  for (HloInstruction* consumer : post_order) {
    for (HloInstruction* producer : consumer->operands()) {
      if (do_not_duplicate.count(producer) > 0) {
        continue;
      }

      // If the producer is effectively not more than unary, duplicating it
      // will not increase the number of relevant inputs read, as the fusion
      // node will only need to read at most 1 relevant input (the input of
      // the producer). In that case, we do not forbid fusion of the operation
      // here.
      if (EffectivelyAtMostUnary(producer)) {
        continue;
      }

      // If the total size of the inputs is less than or equal to the total
      // size of the outputs for the producer, then duplicating it won't
      // increase the memory traffic. In that case, we do not forbid fusion
      // of the operation here.
      auto total_size = [](const Shape& shape) {
        int64 size = 0;
        ShapeUtil::ForEachSubshape(
            shape,
            [&size](const Shape& subshape, const ShapeIndex& shape_index) {
              if (subshape.IsArray()) {
                size += ShapeUtil::ElementsIn(subshape);
              }
            });
        return size;
      };
      int64 operands_size = 0;
      for (const HloInstruction* op : producer->operands()) {
        operands_size += total_size(op->shape());
      }
      if (operands_size <= total_size(producer->shape())) {
        continue;
      }

      // Otherwise we will forbid fusing the op unless we can fuse it into
      // all of its consumers on all paths.
      //
      // That means that for:
      //   A --> B (fusible)
      //     \-> C (non-fusible)
      // A will not be allowed to be fused into B, as it cannot be fused into
      // C.
      //
      // Similarly, for:
      //   A -------------> B
      //     \-> C -> D -/
      // If:
      //   - A is fusible into B and C, and D is fusible into B
      //   - C is *not* fusible into D
      // A will not be allowed to be fused into B, as it cannot be fused via
      // all paths.
      if (producer->IsFusible() &&
          CanFuseOnAllPaths(producer, consumer, do_not_duplicate,
                            &can_fuse_on_all_paths_result_cache)) {
        continue;
      }
      do_not_duplicate.insert(producer);
    }
  }

  return do_not_duplicate;
}

namespace {

// A FusionQueue that uses reverse post order.
//
// We want to be able to remove arbitrary instructions from the post order and
// also compare positions of instructions in the post order. To make this
// possible, create a vector of instructions in post order and create a map
// from HloInstruction* to the instruction's index in the vector. An
// instruction is "removed" from the vector by setting its element to nullptr.
class ReversePostOrderFusionQueue : public FusionQueue {
 public:
  explicit ReversePostOrderFusionQueue(HloComputation* computation) {
    post_order_ = computation->MakeInstructionPostOrder();

    for (size_t i = 0; i < post_order_.size(); ++i) {
      InsertOrDie(&post_order_index_, post_order_[i], i);
    }
  }

  std::pair<HloInstruction*, std::vector<int64>>
  DequeueNextInstructionAndOperandsToFuseInOrder() override {
    // Instructions are "removed" from the post order by nulling out the
    // element in the vector, so if the pointer is null, continue to the next
    // instruction in the sort.
    while (!post_order_.empty() && post_order_.back() == nullptr) {
      post_order_.pop_back();
    }
    if (post_order_.empty()) {
      return std::pair<HloInstruction*, std::vector<int64>>{nullptr, {}};
    }
    // We want to iterate in reverse post order, so remove from the back of
    // the vector.
    HloInstruction* instruction = post_order_.back();
    post_order_.pop_back();

    CHECK(instruction != nullptr);
    // Remove instruction from the index map to ensure the vector and map stay
    // consistent.
    post_order_index_.erase(instruction);

    // Consider each operand of this instruction for fusion into this
    // instruction. We want to consider the operands in a particular order to
    // avoid creating duplicate instruction clones in the fusion instruction.
    // For example, consider the following expression:
    //
    //   A = ...
    //   B = op(A)
    //   C = op(A, B)
    //
    // If we are considering the operands of C for fusion into C, we might
    // fuse A or B first. If we fuse A first, we get:
    //
    //   A = ...
    //   B = op(A)
    //   C_fusion = { A' = ...
    //                C' = op(A', B) }
    //
    // Where A' and C' are clones of A and C, respectively. Now only B is an
    // operand of the fusion instruction C_fusion, so then we fuse B:
    //
    //   A = ...
    //   B = op(A)
    //   C_fusion = { A' = ...
    //                B' = op(A)
    //                C' = op(A', B') }
    //
    // Now A is an operand of C_fusion again, so we then fuse A (again!):
    //
    //   A = ...
    //   B = op(A)
    //   C_fusion = { A' = ...
    //                A" = ...
    //                B' = op(A")
    //                C' = op(A', B') }
    //
    // We prevent this duplication by considering the operands in the order
    // they appear in the queue. In the example, this ensures that B will be
    // considered before A.
    //
    // We store the original indices of the operands to pass to ShouldFuse.
    std::vector<int64> sorted_operand_numbers;
    sorted_operand_numbers.reserve(instruction->operands().size());
    for (int i = 0; i < instruction->operands().size(); ++i) {
      // This will happen if we have two possible instructions to fuse the
      // same operand into; once the operand is fused into one instruction,
      // the other instruction will get a new get-tuple-element as its
      // operand, which is not in the queue.
      // TODO(tjoerg): Look into fusing past these multi-output fuse points.
      if (!ContainsKey(post_order_index_, instruction->mutable_operand(i))) {
        continue;
      }
      sorted_operand_numbers.push_back(i);
    }
    absl::c_sort(sorted_operand_numbers, [&](int64 i, int64 j) {
      // Instructions with higher priority in the queue come first.
      return (FindOrDie(post_order_index_, instruction->mutable_operand(i)) >
              FindOrDie(post_order_index_, instruction->mutable_operand(j)));
    });
    return std::make_pair(instruction, sorted_operand_numbers);
  }

  void OnFusingInstruction(HloInstruction* fusion,
                           HloInstruction* original_producer,
                           HloInstruction* original_consumer) override {
    // Fusing an instruction into a fusion instruction can change the operand
    // set of the fusion instruction. For simplicity just re-enqueue the
    // instruction and reconsider it for further fusion in the next iteration.
    InsertOrDie(&post_order_index_, fusion, post_order_.size());
    post_order_.push_back(fusion);
  }

  void RemoveInstruction(HloInstruction* instruction) override {
    post_order_[FindOrDie(post_order_index_, instruction)] = nullptr;
    post_order_index_.erase(instruction);
  }

 private:
  std::vector<HloInstruction*> post_order_;
  absl::flat_hash_map<HloInstruction*, int> post_order_index_;
};

}  // namespace

std::unique_ptr<FusionQueue> InstructionFusion::GetFusionQueue(
    HloComputation* computation) {
  return absl::make_unique<ReversePostOrderFusionQueue>(computation);
}

StatusOr<bool> InstructionFusion::Run(HloModule* module) {
  VLOG(2) << "Before instruction fusion:";
  XLA_VLOG_LINES(2, module->ToString());

  bool changed = false;
  module_ = module;
  for (auto* computation : module->MakeNonfusionComputations()) {
    CHECK(!computation->IsFusionComputation());
    computation_ = computation;
    reachability_ = HloReachabilityMap::Build(computation_);

    HloInstructionSet do_not_duplicate;
    // If we allow duplications, we need to compute which instructions we do
    // not want to duplicate based on a global analysis of the graph.
    if (may_duplicate_) {
      do_not_duplicate =
          ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder());
    }
    auto fusion_queue = GetFusionQueue(computation_);

    // Instruction fusion effectively fuses edges in the computation graph
    // (producer instruction -> consumer instruction) so we iterate over all
    // edges. When we fuse an edge, we create a copy of the producer inside
    // the fusion instruction.
    while (true) {
      auto next_entry =
          fusion_queue->DequeueNextInstructionAndOperandsToFuseInOrder();
      auto instruction = next_entry.first;
      if (instruction == nullptr) {
        break;
      }

      if (!instruction->IsFusible() &&
          instruction->opcode() != HloOpcode::kFusion) {
        continue;
      }

      std::vector<int64>& sorted_operand_numbers = next_entry.second;

      for (int64 i : sorted_operand_numbers) {
        HloInstruction* operand = instruction->mutable_operand(i);

        if (!operand->IsFusible()) {
          continue;
        }

        HloInstruction* fusion_instruction;
        // Try "regular" fusion if the operand may be duplicated. Otherwise,
        // perform multi-output fusion, unless this creates a cycle.
        if (do_not_duplicate.count(operand) == 0 &&
            ShouldFuse(instruction, i)) {
          fusion_queue->PreFusion(operand, instruction);
          fusion_instruction = Fuse(operand, instruction);
        } else if (ShouldFuseIntoMultiOutput(instruction, i) &&
                   !MultiOutputFusionCreatesCycle(operand, instruction)) {
          fusion_queue->PreFusion(operand, instruction);
          fusion_instruction = FuseIntoMultiOutput(operand, instruction);
        } else {
          continue;
        }

        fusion_queue->OnFusingInstruction(fusion_instruction, operand,
                                          instruction);
        changed = true;

        if (operand->user_count() == 0) {
          do_not_duplicate.erase(operand);
          // Operand is now dead. Remove from queue.
          fusion_queue->RemoveInstruction(operand);
          // Remove from computation.
          TF_RETURN_IF_ERROR(computation_->RemoveInstruction(operand));
        }

        if (fusion_instruction != instruction) {
          do_not_duplicate.erase(instruction);
        }
        break;
      }
    }
  }

  VLOG(2) << "After instruction fusion:";
  XLA_VLOG_LINES(2, module->ToString());

  return changed;
}

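// Returns the fusion instruction that `producer` will be fused into: the
// consumer itself if it is already a fusion instruction (updating its fusion
// kind if needed), otherwise a newly created fusion instruction that wraps
// the consumer and replaces it in the computation.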
HloInstruction* InstructionFusion::AddFusionInstruction(
    HloInstruction* producer, HloInstruction* consumer) {
  HloInstruction* fusion_instruction;
  auto kind = ChooseKind(producer, consumer);
  if (consumer->opcode() == HloOpcode::kFusion) {
    fusion_instruction = consumer;
    if (kind != fusion_instruction->fusion_kind()) {
      fusion_instruction->set_fusion_kind(kind);
    }
  } else {
    fusion_instruction = computation_->AddInstruction(
        HloInstruction::CreateFusion(consumer->shape(), kind, consumer));
    TF_CHECK_OK(
        computation_->ReplaceInstruction(consumer, fusion_instruction));
  }
  return fusion_instruction;
}

HloInstruction* InstructionFusion::Fuse(HloInstruction* producer,
                                        HloInstruction* consumer) {
  VLOG(2) << "Fusing " << producer->ToString() << " into "
          << consumer->ToString();
  HloInstruction* fusion_instruction =
      AddFusionInstruction(producer, consumer);
  fusion_instruction->FuseInstruction(producer);
  return fusion_instruction;
}

HloInstruction* InstructionFusion::FuseIntoMultiOutput(
    HloInstruction* producer, HloInstruction* consumer) {
  VLOG(2) << "Multi-output fusing " << producer->ToString() << " into "
          << consumer->ToString();
  HloInstruction* fusion_instruction =
      AddFusionInstruction(producer, consumer);
  fusion_instruction->FuseInstructionIntoMultiOutput(producer);
  return fusion_instruction;
}

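// Returns true if multi-output fusing `producer` into `consumer` would
// introduce a cycle, i.e. some other operand of the consumer is reachable
// from the producer.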
bool InstructionFusion::MultiOutputFusionCreatesCycle(
    HloInstruction* producer, HloInstruction* consumer) {
  absl::flat_hash_set<int> operands;
  for (const HloInstruction* operand : consumer->operands()) {
    if (operand == producer) {
      continue;
    }

    // If the reachability map already contains the producer and the operand
    // of the consumer, and the producer can reach the operand, then we know
    // for sure MultiOutputFusion would create a cycle. If not, we need to do
    // a DFS traversal of the computation to verify that this multi-output
    // fusion would not create a cycle.
    if (reachability_->IsPresent(producer) &&
        reachability_->IsPresent(operand) &&
        reachability_->IsReachable(producer, operand)) {
      return true;
    }
    operands.insert(operand->unique_id());
  }

  // Do a DFS on the producer to see if any of the other consumer operands are
  // reachable in the current state of the graph.
  std::vector<HloInstruction*> worklist = producer->users();
  absl::flat_hash_set<int> visits;
  while (!worklist.empty()) {
    const HloInstruction* user = worklist.back();
    worklist.pop_back();
    if (operands.count(user->unique_id()) != 0) {
      return true;
    }
    if (visits.count(user->unique_id()) == 0) {
      visits.insert(user->unique_id());
      worklist.insert(worklist.end(), user->users().begin(),
                      user->users().end());
    }
  }
  return false;
}

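// Default profitability check for fusing the operand at `operand_index` into
// `consumer`: don't duplicate expensive instructions, only fuse into
// loop/input/output fusion kinds, and keep bitcast-like producers of
// parameters unfused so they can become bitcasts.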
bool InstructionFusion::ShouldFuse(HloInstruction* consumer,
                                   int64 operand_index) {
  HloInstruction* producer = consumer->mutable_operand(operand_index);

  // Cost condition: don't duplicate expensive instructions.
  if (FusionWouldDuplicate(*producer, *consumer) &&
      (!may_duplicate_ || is_expensive_(*producer)) &&
      !IsAlwaysDuplicable(*producer)) {
    return false;
  }

  if (consumer->opcode() == HloOpcode::kFusion &&
      consumer->fusion_kind() != HloInstruction::FusionKind::kLoop &&
      consumer->fusion_kind() != HloInstruction::FusionKind::kInput &&
      consumer->fusion_kind() != HloInstruction::FusionKind::kOutput) {
    return false;
  }

  if (producer->CouldBeBitcast() &&
      // We can't fuse parameters anyhow, so we leave the user unfused to
      // become a bitcast. If the operand is not a parameter, we would break a
      // potential fusion to make it a bitcast, which is not so clear a win.
      producer->operand(0)->opcode() == HloOpcode::kParameter) {
    return false;
  }

  return true;
}

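// By default every fusion is a plain loop fusion; backends can override this
// to choose a more specialized fusion kind for a given producer/consumer
// pair.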
HloInstruction::FusionKind InstructionFusion::ChooseKind(
    const HloInstruction* producer, const HloInstruction* consumer) {
  return HloInstruction::FusionKind::kLoop;
}

}  // namespace xla