• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/instruction_fusion.h"
17 
18 #include <algorithm>
19 #include <list>
20 #include <memory>
21 #include <numeric>
22 #include <vector>
23 
24 #include "absl/algorithm/container.h"
25 #include "absl/container/flat_hash_map.h"
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/memory/memory.h"
28 #include "tensorflow/compiler/xla/map_util.h"
29 #include "tensorflow/compiler/xla/service/fusion_queue.h"
30 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
31 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
32 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
33 #include "tensorflow/core/lib/core/errors.h"
34 #include "tensorflow/core/platform/logging.h"
35 
36 namespace xla {
37 namespace {
38 // These nodes can always be duplicated into consumers, even if
39 // InstructionFusion::may_duplicate_ is false.
40 //
41 // In general these should be nodes that get *cheaper* the more they're
42 // duplicated (and fused into consumers).
43 //
44 // TODO(jlebar): Duplicating instructions when we have a variable called "may
45 // duplicate" that's equal to false is not pretty.
IsAlwaysDuplicable(const HloInstruction & instruction)46 bool IsAlwaysDuplicable(const HloInstruction& instruction) {
47   // We are always willing to duplicate a widening type-conversion instruction
48   // if it means we can fuse the convert into a consumer.  This allows the
49   // consumer to read less memory, which is almost always a performance win.
50   return instruction.opcode() == HloOpcode::kConvert &&
51          ShapeUtil::ByteSizeOf(instruction.operand(0)->shape()) <
52              ShapeUtil::ByteSizeOf(instruction.shape());
53 }
54 }  // namespace
55 
IsExpensive(const HloInstruction & instruction)56 /*static*/ bool InstructionFusion::IsExpensive(
57     const HloInstruction& instruction) {
58   switch (instruction.opcode()) {
59     // Cheap instructions.
60     case HloOpcode::kAdd:
61     case HloOpcode::kAnd:
62     case HloOpcode::kBitcast:
63     case HloOpcode::kBitcastConvert:
64     case HloOpcode::kBroadcast:
65     case HloOpcode::kCeil:
66     case HloOpcode::kClamp:
67     case HloOpcode::kClz:
68     case HloOpcode::kCompare:
69     case HloOpcode::kComplex:
70     case HloOpcode::kConcatenate:
71     case HloOpcode::kConstant:
72     case HloOpcode::kConvert:
73     case HloOpcode::kCopy:
74     case HloOpcode::kDynamicSlice:
75     case HloOpcode::kDynamicUpdateSlice:
76     case HloOpcode::kFloor:
77     case HloOpcode::kGetTupleElement:
78     case HloOpcode::kImag:
79     case HloOpcode::kInfeed:
80     case HloOpcode::kIota:
81     case HloOpcode::kIsFinite:
82     case HloOpcode::kMaximum:
83     case HloOpcode::kMinimum:
84     case HloOpcode::kMultiply:
85     case HloOpcode::kNegate:
86     case HloOpcode::kNot:
87     case HloOpcode::kOr:
88     case HloOpcode::kXor:
89     case HloOpcode::kOutfeed:
90     case HloOpcode::kPad:
91     case HloOpcode::kReal:
92     case HloOpcode::kReducePrecision:
93     case HloOpcode::kReplicaId:
94     case HloOpcode::kReshape:
95     case HloOpcode::kReverse:
96     case HloOpcode::kRoundNearestAfz:
97     case HloOpcode::kSelect:
98     case HloOpcode::kShiftLeft:
99     case HloOpcode::kShiftRightArithmetic:
100     case HloOpcode::kShiftRightLogical:
101     case HloOpcode::kSlice:
102     case HloOpcode::kSubtract:
103     case HloOpcode::kTranspose:
104     case HloOpcode::kTuple:
105     case HloOpcode::kTupleSelect:
106       return false;
107 
108     // Cheap instructions for reals, but expensive for complex.
109     case HloOpcode::kAbs:
110     case HloOpcode::kCos:
111     case HloOpcode::kSign:
112     case HloOpcode::kSin:
113       return ShapeUtil::ElementIsComplex(instruction.shape());
114 
115     // Expensive instructions or unusual instructions for which fusion is
116     // nonsensical.
117     case HloOpcode::kAddDependency:
118     case HloOpcode::kAfterAll:
119     case HloOpcode::kAtan2:
120     case HloOpcode::kBatchNormGrad:
121     case HloOpcode::kBatchNormInference:
122     case HloOpcode::kBatchNormTraining:
123     case HloOpcode::kCall:
124     case HloOpcode::kCholesky:
125     case HloOpcode::kConditional:
126     case HloOpcode::kConvolution:
127     case HloOpcode::kAllReduce:
128     case HloOpcode::kAllToAll:
129     case HloOpcode::kCollectivePermute:
130     case HloOpcode::kCustomCall:
131     case HloOpcode::kDivide:
132     case HloOpcode::kDomain:
133     case HloOpcode::kDot:
134     case HloOpcode::kExp:
135     case HloOpcode::kExpm1:
136     case HloOpcode::kFft:
137     case HloOpcode::kFusion:
138     case HloOpcode::kGather:
139     case HloOpcode::kLog:
140     case HloOpcode::kLog1p:
141     case HloOpcode::kMap:
142     case HloOpcode::kParameter:
143     case HloOpcode::kPower:
144     case HloOpcode::kRecv:
145     case HloOpcode::kRecvDone:
146     case HloOpcode::kReduce:
147     case HloOpcode::kReduceWindow:
148     case HloOpcode::kRemainder:
149     case HloOpcode::kRng:
150     case HloOpcode::kRsqrt:
151     case HloOpcode::kScatter:
152     case HloOpcode::kSelectAndScatter:
153     case HloOpcode::kSend:
154     case HloOpcode::kSendDone:
155     case HloOpcode::kSort:
156     case HloOpcode::kSqrt:
157     case HloOpcode::kTanh:
158     case HloOpcode::kTrace:
159     case HloOpcode::kTriangularSolve:
160     case HloOpcode::kWhile:
161     case HloOpcode::kGetDimensionSize:
162       return true;
163   }
164 
165   return false;
166 }
167 
168 // An "effectively at most unary" operation is one that has at most one "large"
169 // input with the others being negligible in terms of memory usage.
170 // We use "has a smaller true rank than the output" as a heuristic
171 // for "negligible" memory usage.
EffectivelyAtMostUnary(HloInstruction * hlo)172 bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) {
173   int64 output_rank = 0;
174   ShapeUtil::ForEachSubshape(
175       hlo->shape(),
176       [&output_rank](const Shape& subshape, const ShapeIndex& shape_index) {
177         if (subshape.IsArray()) {
178           output_rank = std::max(output_rank, ShapeUtil::TrueRank(subshape));
179         }
180       });
181   return absl::c_count_if(
182              hlo->operands(), [output_rank](HloInstruction* operand) {
183                if (operand->opcode() == HloOpcode::kBroadcast ||
184                    operand->opcode() == HloOpcode::kIota) {
185                  return false;
186                }
187                if (operand->opcode() == HloOpcode::kConstant &&
188                    ShapeUtil::IsEffectiveScalar(operand->shape())) {
189                  return false;
190                }
191                return ShapeUtil::TrueRank(operand->shape()) >= output_rank;
192              }) <= 1;
193 }
194 
CanFuseOnAllPaths(HloInstruction * producer,HloInstruction * consumer,const HloInstructionSet & do_not_fuse,absl::flat_hash_map<std::pair<HloInstruction *,HloInstruction * >,bool> * result_cache)195 bool InstructionFusion::CanFuseOnAllPaths(
196     HloInstruction* producer, HloInstruction* consumer,
197     const HloInstructionSet& do_not_fuse,
198     absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>*
199         result_cache) {
200   if (consumer == producer) {
201     return true;
202   }
203   if (!consumer->IsFusible()) {
204     return false;
205   }
206   auto cache_it = result_cache->find(std::make_pair(producer, consumer));
207   if (cache_it != result_cache->end()) {
208     return cache_it->second;
209   }
210   bool result = true;
211   for (int64 i = 0, e = consumer->operand_count(); i < e; ++i) {
212     auto* consumer_operand = consumer->mutable_operand(i);
213     // If the operand is not on a path to the producer, it doesn't matter
214     // whether it's fusible.
215     if (!reachability_->IsReachable(producer, consumer_operand)) {
216       continue;
217     }
218     if (do_not_fuse.count(consumer_operand) > 0 || !ShouldFuse(consumer, i)) {
219       result = false;
220       break;
221     }
222     // The producer is reachable from consumer_operand which means we need
223     // to be able to fuse consumer_operand into consumer in order for
224     // producer to be fusible into consumer on all paths.
225     // Perform the recursive step: make sure producer can be fused into
226     // consumer_operand on all paths.
227     if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_fuse,
228                            result_cache)) {
229       result = false;
230       break;
231     }
232   }
233   result_cache->emplace(std::make_pair(producer, consumer), result);
234   return result;
235 }
236 
237 InstructionFusion::HloInstructionSet
ComputeGloballyUnfusible(absl::Span<HloInstruction * const> post_order)238 InstructionFusion::ComputeGloballyUnfusible(
239     absl::Span<HloInstruction* const> post_order) {
240   // Forbid fusion of producers that:
241   // a) Need to be duplicated, unless they can be fused into all consumers
242   //    via all paths.
243   // b) Are more than unary, that is, fusing them would likely lead to an
244   //    increase in memory bandwidth use.
245   //
246   // Note that if we allow fusion by these global rules, we may still forbid
247   // fusing operations that require duplication later depending on
248   // is_expensive_().
249   HloInstructionSet do_not_duplicate;
250   absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>
251       can_fuse_on_all_paths_result_cache;
252   for (HloInstruction* consumer : post_order) {
253     for (HloInstruction* producer : consumer->operands()) {
254       if (do_not_duplicate.count(producer) > 0) {
255         continue;
256       }
257 
258       // If the producer is effectively not more than unary, duplicating it
259       // will not increase the number of relevant inputs read, as the fusion
260       // node will only need to read at most 1 relevant input (the input of
261       // the producer). In that case, we do not forbid fusion of the operation
262       // here.
263       if (EffectivelyAtMostUnary(producer)) {
264         continue;
265       }
266 
267       // If the total size of the inputs is less than or equal to the total size
268       // of the outputs for the producer then duplicating it won't increase the
269       // memory traffic. In that case, we do not forbid fusion of the operation
270       // here.
271       auto total_size = [](const Shape& shape) {
272         int64 size = 0;
273         ShapeUtil::ForEachSubshape(
274             shape,
275             [&size](const Shape& subshape, const ShapeIndex& shape_index) {
276               if (subshape.IsArray()) {
277                 size += ShapeUtil::ElementsIn(subshape);
278               }
279             });
280         return size;
281       };
282       int64 operands_size = 0;
283       for (const HloInstruction* op : producer->operands()) {
284         operands_size += total_size(op->shape());
285       }
286       if (operands_size <= total_size(producer->shape())) {
287         continue;
288       }
289 
290       // Otherwise we will forbid fusing the op unless we can fuse it into
291       // all of its consumers on all paths.
292       //
293       // That means, that for:
294       // A --> B (fusible)
295       //   \-> C (non-fusible)
296       // A will be not allowed to be fused into B, as it cannot be fused into C.
297       //
298       // Similarly, for:
299       // A -------------> B
300       //   \-> C -> D -/
301       // If:
302       // - A is fusible into B and C, and D is fusible into B
303       // - C is *not* fusible into D
304       // A will be not allowed to be fused into B, as it cannot be fused via
305       // all paths.
306       if (producer->IsFusible() &&
307           CanFuseOnAllPaths(producer, consumer, do_not_duplicate,
308                             &can_fuse_on_all_paths_result_cache)) {
309         continue;
310       }
311       do_not_duplicate.insert(producer);
312     }
313   }
314 
315   return do_not_duplicate;
316 }
317 
318 namespace {
319 
320 // A FusionQueue that uses reverse post order.
321 //
322 // We want to be able to remove arbitrary instructions from the post order and
323 // also compare positions of instructions in the post order. To make this
324 // possible, create vector of instructions in post order and create a map from
325 // HloInstruction* to the instruction's index in the vector. An instruction is
326 // "removed" from the vector by setting it's element to nullptr.
327 class ReversePostOrderFusionQueue : public FusionQueue {
328  public:
ReversePostOrderFusionQueue(HloComputation * computation)329   explicit ReversePostOrderFusionQueue(HloComputation* computation) {
330     post_order_ = computation->MakeInstructionPostOrder();
331 
332     for (size_t i = 0; i < post_order_.size(); ++i) {
333       InsertOrDie(&post_order_index_, post_order_[i], i);
334     }
335   }
336 
337   std::pair<HloInstruction*, std::vector<int64>>
DequeueNextInstructionAndOperandsToFuseInOrder()338   DequeueNextInstructionAndOperandsToFuseInOrder() override {
339     // Instructions are "removed" from the post order by nulling out the element
340     // in the vector, so if the pointer is null, continue to the next
341     // instruction in the sort.
342     while (!post_order_.empty() && post_order_.back() == nullptr) {
343       post_order_.pop_back();
344     }
345     if (post_order_.empty()) {
346       return std::pair<HloInstruction*, std::vector<int64>>{nullptr, {}};
347     }
348     // We want to iterate in reverse post order, so remove from the back of the
349     // vector.
350     HloInstruction* instruction = post_order_.back();
351     post_order_.pop_back();
352 
353     CHECK(instruction != nullptr);
354     // Remove instruction from the index map to ensure the vector and map stay
355     // consistent.
356     post_order_index_.erase(instruction);
357 
358     // Consider each operand of this instruction for fusion into this
359     // instruction. We want to consider the operands in a particular order to
360     // avoid creating duplicate instruction clones in the fusion instruction.
361     // For example, consider the following expression:
362     //
363     //   A = ...
364     //   B = op(A)
365     //   C = op(A, B)
366     //
367     // If we are considering the operands of C for fusion into C. We might
368     // fuse A or B first. If we fuse A first, we get:
369     //
370     //   A = ...
371     //   B = op(A)
372     //   C_fusion = { A' = ...
373     //                C' = op(A', B) }
374     //
375     // Where A' and C' are clones of A and C, respectively. Now only B is an
376     // operand of the fusion instruction C_fusion, so then we fuse B:
377     //
378     //   A = ...
379     //   B = op(A)
380     //   C_fusion = { A' = ...
381     //                B' = op(A)
382     //                C' = op(A', B') }
383     //
384     // Now A is an operand of C_fusion again, so we then fuse A (again!):
385     //
386     //   A = ...
387     //   B = op(A)
388     //   C_fusion = { A' = ...
389     //                A" = ..
390     //                B' = op(A")
391     //                C' = op(A', B') }
392     //
393     // We prevent this duplication by considering the operands in the order
394     // they appear int the queue. In the example, this ensures that B will be
395     // considered before A.
396     //
397     // We store the original indices of the operands to pass to ShouldFuse.
398     std::vector<int64> sorted_operand_numbers;
399     sorted_operand_numbers.reserve(instruction->operands().size());
400     for (int i = 0; i < instruction->operands().size(); ++i) {
401       // This will happen if we have two possible instructions to fuse the
402       // same operand into; once the operand is fused into one instruction,
403       // the other instruction will get a new get-tuple-element as its
404       // operand, which is not in the queue.
405       // TODO(tjoerg): Look into fusing past these multi-output fuse points.
406       if (!ContainsKey(post_order_index_, instruction->mutable_operand(i))) {
407         continue;
408       }
409       sorted_operand_numbers.push_back(i);
410     }
411     absl::c_sort(
412         sorted_operand_numbers, [&](int64 i, int64 j) {
413           // Instructions with higher priority in the queue come first.
414           return (
415               FindOrDie(post_order_index_, instruction->mutable_operand(i)) >
416               FindOrDie(post_order_index_, instruction->mutable_operand(j)));
417         });
418     return std::make_pair(instruction, sorted_operand_numbers);
419   }
420 
OnFusingInstruction(HloInstruction * fusion,HloInstruction * original_producer,HloInstruction * original_consumer)421   void OnFusingInstruction(HloInstruction* fusion,
422                            HloInstruction* original_producer,
423                            HloInstruction* original_consumer) override {
424     // Fusing an instruction into a fusion instruction can change the operand
425     // set of the fusion instruction. For simplicity just re-enqueue the
426     // instruction and reconsider it for further fusion in the next iteration.
427     InsertOrDie(&post_order_index_, fusion, post_order_.size());
428     post_order_.push_back(fusion);
429   }
430 
RemoveInstruction(HloInstruction * instruction)431   void RemoveInstruction(HloInstruction* instruction) override {
432     post_order_[FindOrDie(post_order_index_, instruction)] = nullptr;
433     post_order_index_.erase(instruction);
434   }
435 
436  private:
437   std::vector<HloInstruction*> post_order_;
438   absl::flat_hash_map<HloInstruction*, int> post_order_index_;
439 };
440 
441 }  // namespace
442 
GetFusionQueue(HloComputation * computation)443 std::unique_ptr<FusionQueue> InstructionFusion::GetFusionQueue(
444     HloComputation* computation) {
445   return absl::make_unique<ReversePostOrderFusionQueue>(computation);
446 }
447 
Run(HloModule * module)448 StatusOr<bool> InstructionFusion::Run(HloModule* module) {
449   VLOG(2) << "Before instruction fusion:";
450   XLA_VLOG_LINES(2, module->ToString());
451 
452   bool changed = false;
453   module_ = module;
454   for (auto* computation : module->MakeNonfusionComputations()) {
455     CHECK(!computation->IsFusionComputation());
456     computation_ = computation;
457     reachability_ = HloReachabilityMap::Build(computation_);
458 
459     HloInstructionSet do_not_duplicate;
460     // If we allow duplications, we need to compute which instructions we do not
461     // want to duplicate based on a global analysis of the graph.
462     if (may_duplicate_) {
463       do_not_duplicate =
464           ComputeGloballyUnfusible(computation_->MakeInstructionPostOrder());
465     }
466     auto fusion_queue = GetFusionQueue(computation_);
467 
468     // Instruction fusion effectively fuses edges in the computation graph
469     // (producer instruction -> consumer instruction) so we iterate over all
470     // edges. When we fuse an edge, we create a copy of the producer inside the
471     // fusion instruction.
472     while (true) {
473       auto next_entry =
474           fusion_queue->DequeueNextInstructionAndOperandsToFuseInOrder();
475       auto instruction = next_entry.first;
476       if (instruction == nullptr) {
477         break;
478       }
479 
480       if (!instruction->IsFusible() &&
481           instruction->opcode() != HloOpcode::kFusion) {
482         continue;
483       }
484 
485       std::vector<int64>& sorted_operand_numbers = next_entry.second;
486 
487       for (int64 i : sorted_operand_numbers) {
488         HloInstruction* operand = instruction->mutable_operand(i);
489 
490         if (!operand->IsFusible()) {
491           continue;
492         }
493 
494         HloInstruction* fusion_instruction;
495         // Try "regular" fusion if the operand may be duplicated. Otherwise,
496         // perform multi-output fusion, unless this creates a cycle.
497         if (do_not_duplicate.count(operand) == 0 &&
498             ShouldFuse(instruction, i)) {
499           fusion_queue->PreFusion(operand, instruction);
500           fusion_instruction = Fuse(operand, instruction);
501         } else if (ShouldFuseIntoMultiOutput(instruction, i) &&
502                    !MultiOutputFusionCreatesCycle(operand, instruction)) {
503           fusion_queue->PreFusion(operand, instruction);
504           fusion_instruction = FuseIntoMultiOutput(operand, instruction);
505         } else {
506           continue;
507         }
508 
509         fusion_queue->OnFusingInstruction(fusion_instruction, operand,
510                                           instruction);
511         changed = true;
512 
513         if (operand->user_count() == 0) {
514           do_not_duplicate.erase(operand);
515           // Operand is now dead. Remove from queue.
516           fusion_queue->RemoveInstruction(operand);
517           // Remove from computation.
518           TF_RETURN_IF_ERROR(computation_->RemoveInstruction(operand));
519         }
520 
521         if (fusion_instruction != instruction) {
522           do_not_duplicate.erase(instruction);
523         }
524         break;
525       }
526     }
527   }
528 
529   VLOG(2) << "After instruction fusion:";
530   XLA_VLOG_LINES(2, module->ToString());
531 
532   return changed;
533 }
534 
AddFusionInstruction(HloInstruction * producer,HloInstruction * consumer)535 HloInstruction* InstructionFusion::AddFusionInstruction(
536     HloInstruction* producer, HloInstruction* consumer) {
537   HloInstruction* fusion_instruction;
538   auto kind = ChooseKind(producer, consumer);
539   if (consumer->opcode() == HloOpcode::kFusion) {
540     fusion_instruction = consumer;
541     if (kind != fusion_instruction->fusion_kind()) {
542       fusion_instruction->set_fusion_kind(kind);
543     }
544   } else {
545     fusion_instruction = computation_->AddInstruction(
546         HloInstruction::CreateFusion(consumer->shape(), kind, consumer));
547     TF_CHECK_OK(computation_->ReplaceInstruction(consumer, fusion_instruction));
548   }
549   return fusion_instruction;
550 }
551 
Fuse(HloInstruction * producer,HloInstruction * consumer)552 HloInstruction* InstructionFusion::Fuse(HloInstruction* producer,
553                                         HloInstruction* consumer) {
554   VLOG(2) << "Fusing " << producer->ToString() << " into "
555           << consumer->ToString();
556   HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer);
557   fusion_instruction->FuseInstruction(producer);
558   return fusion_instruction;
559 }
560 
FuseIntoMultiOutput(HloInstruction * producer,HloInstruction * consumer)561 HloInstruction* InstructionFusion::FuseIntoMultiOutput(
562     HloInstruction* producer, HloInstruction* consumer) {
563   VLOG(2) << "Multi-output fusing " << producer->ToString() << " into "
564           << consumer->ToString();
565   HloInstruction* fusion_instruction = AddFusionInstruction(producer, consumer);
566   fusion_instruction->FuseInstructionIntoMultiOutput(producer);
567   return fusion_instruction;
568 }
569 
MultiOutputFusionCreatesCycle(HloInstruction * producer,HloInstruction * consumer)570 bool InstructionFusion::MultiOutputFusionCreatesCycle(
571     HloInstruction* producer, HloInstruction* consumer) {
572   absl::flat_hash_set<int> operands;
573   for (const HloInstruction* operand : consumer->operands()) {
574     if (operand == producer) {
575       continue;
576     }
577 
578     // If the reachability map already contains the producer and the operand of
579     // the consumer, and the producer can reach the operand, then we know for
580     // sure MultiOutputFusion would create a cycle. If not, we need to do a DFS
581     // traversal of the computation to verify that this multioutput fusion would
582     // not create a cycle.
583     if (reachability_->IsPresent(producer) &&
584         reachability_->IsPresent(operand) &&
585         reachability_->IsReachable(producer, operand)) {
586       return true;
587     }
588     operands.insert(operand->unique_id());
589   }
590 
591   // Do a DFS on the producer to see if any of the other consumer operands are
592   // reachable in the current state of the graph.
593   std::vector<HloInstruction*> worklist = producer->users();
594   absl::flat_hash_set<int> visits;
595   while (!worklist.empty()) {
596     const HloInstruction* user = worklist.back();
597     worklist.pop_back();
598     if (operands.count(user->unique_id()) != 0) {
599       return true;
600     }
601     if (visits.count(user->unique_id()) == 0) {
602       visits.insert(user->unique_id());
603       worklist.insert(worklist.end(), user->users().begin(),
604                       user->users().end());
605     }
606   }
607   return false;
608 }
609 
ShouldFuse(HloInstruction * consumer,int64 operand_index)610 bool InstructionFusion::ShouldFuse(HloInstruction* consumer,
611                                    int64 operand_index) {
612   HloInstruction* producer = consumer->mutable_operand(operand_index);
613 
614   // Cost condition: don't duplicate expensive instructions.
615   if (FusionWouldDuplicate(*producer, *consumer) &&
616       (!may_duplicate_ || is_expensive_(*producer)) &&
617       !IsAlwaysDuplicable(*producer)) {
618     return false;
619   }
620 
621   if (consumer->opcode() == HloOpcode::kFusion &&
622       consumer->fusion_kind() != HloInstruction::FusionKind::kLoop &&
623       consumer->fusion_kind() != HloInstruction::FusionKind::kInput &&
624       consumer->fusion_kind() != HloInstruction::FusionKind::kOutput) {
625     return false;
626   }
627 
628   if (producer->CouldBeBitcast() &&
629       // We can't fuse parameters anyhow, so we leave the user unfused to become
630       // a bitcast. If the operand is not a parameter, we would break a
631       // potential fusion to make it a bitcast, which is not so clear a win.
632       producer->operand(0)->opcode() == HloOpcode::kParameter) {
633     return false;
634   }
635 
636   return true;
637 }
638 
// Picks the fusion kind used when fusing `producer` into `consumer`.  This
// implementation ignores both arguments and always selects loop fusion.
HloInstruction::FusionKind InstructionFusion::ChooseKind(
    const HloInstruction* producer, const HloInstruction* consumer) {
  constexpr auto kDefaultKind = HloInstruction::FusionKind::kLoop;
  return kDefaultKind;
}
643 
644 }  // namespace xla
645