• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/instruction_fusion.h"
17 
18 #include <algorithm>
19 #include <functional>
20 #include <list>
21 #include <memory>
22 #include <numeric>
23 #include <string>
24 #include <utility>
25 #include <vector>
26 
27 #include "absl/algorithm/container.h"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/container/flat_hash_set.h"
30 #include "tensorflow/compiler/xla/debug_options_flags.h"
31 #include "tensorflow/compiler/xla/map_util.h"
32 #include "tensorflow/compiler/xla/service/fusion_queue.h"
33 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
34 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
35 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
36 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
37 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
38 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/platform/logging.h"
41 
42 namespace xla {
43 namespace {
44 
45 // These nodes can always be duplicated into consumers, even if
46 // InstructionFusion::may_duplicate_ is false.
47 //
48 // In general these should be nodes that get *cheaper* the more they're
49 // duplicated (and fused into consumers).
50 //
51 // TODO(jlebar): Duplicating instructions when we have a variable called "may
52 // duplicate" that's equal to false is not pretty.
IsAlwaysDuplicable(const HloInstruction & instruction)53 bool IsAlwaysDuplicable(const HloInstruction& instruction) {
54   // We are always willing to duplicate a widening type-conversion instruction
55   // if it means we can fuse the convert into a consumer.  This allows the
56   // consumer to read less memory, which is almost always a performance win.
57   return instruction.opcode() == HloOpcode::kConvert &&
58          ShapeUtil::ByteSizeOf(instruction.operand(0)->shape()) <
59              ShapeUtil::ByteSizeOf(instruction.shape());
60 }
61 }  // namespace
62 
IsExpensive(const HloInstruction & instruction)63 /*static*/ bool InstructionFusion::IsExpensive(
64     const HloInstruction& instruction) {
65   namespace m = match;
66 
67   switch (instruction.opcode()) {
68     // Cheap instructions.
69     case HloOpcode::kAdd:
70     case HloOpcode::kAnd:
71     case HloOpcode::kBitcast:
72     case HloOpcode::kBitcastConvert:
73     case HloOpcode::kBroadcast:
74     case HloOpcode::kCeil:
75     case HloOpcode::kClamp:
76     case HloOpcode::kClz:
77     case HloOpcode::kCompare:
78     case HloOpcode::kComplex:
79     case HloOpcode::kConcatenate:
80     case HloOpcode::kConstant:
81     case HloOpcode::kConvert:
82     case HloOpcode::kCopy:
83     case HloOpcode::kCopyDone:
84     case HloOpcode::kCopyStart:
85     case HloOpcode::kDynamicSlice:
86     case HloOpcode::kDynamicUpdateSlice:
87     case HloOpcode::kFloor:
88     case HloOpcode::kGetTupleElement:
89     case HloOpcode::kImag:
90     case HloOpcode::kInfeed:
91     case HloOpcode::kIota:
92     case HloOpcode::kIsFinite:
93     case HloOpcode::kMaximum:
94     case HloOpcode::kMinimum:
95     case HloOpcode::kMultiply:
96     case HloOpcode::kNegate:
97     case HloOpcode::kNot:
98     case HloOpcode::kOptimizationBarrier:
99     case HloOpcode::kOr:
100     case HloOpcode::kXor:
101     case HloOpcode::kOutfeed:
102     case HloOpcode::kPad:
103     case HloOpcode::kPartitionId:
104     case HloOpcode::kPopulationCount:
105     case HloOpcode::kReal:
106     case HloOpcode::kReducePrecision:
107     case HloOpcode::kReplicaId:
108     case HloOpcode::kReshape:
109     case HloOpcode::kDynamicReshape:
110     case HloOpcode::kReverse:
111     case HloOpcode::kRoundNearestAfz:
112     case HloOpcode::kRoundNearestEven:
113     case HloOpcode::kSelect:
114     case HloOpcode::kShiftLeft:
115     case HloOpcode::kShiftRightArithmetic:
116     case HloOpcode::kShiftRightLogical:
117     case HloOpcode::kSlice:
118     case HloOpcode::kSubtract:
119     case HloOpcode::kTranspose:
120     case HloOpcode::kTuple:
121       return false;
122 
123     // Cheap instructions for reals, but expensive for complex.
124     case HloOpcode::kAbs:
125     case HloOpcode::kCos:
126     case HloOpcode::kSign:
127     case HloOpcode::kSin:
128       return ShapeUtil::ElementIsComplex(instruction.shape());
129 
130     // We say that integer div/mod by a constant is cheap because it gets
131     // compiled down to multiplies and shifts, and we consider those to be
132     // cheap.
133     case HloOpcode::kDivide:
134     case HloOpcode::kRemainder:
135       return !ShapeUtil::ElementIsIntegral(instruction.shape()) ||
136              !Match(instruction.operand(1),
137                     m::AnyOf<const HloInstruction>(
138                         m::ConstantEffectiveScalar(),
139                         m::Broadcast(m::ConstantEffectiveScalar())));
140 
141     // Expensive instructions or unusual instructions for which fusion is
142     // nonsensical.
143     case HloOpcode::kAddDependency:
144     case HloOpcode::kAfterAll:
145     case HloOpcode::kAtan2:
146     case HloOpcode::kAsyncStart:
147     case HloOpcode::kAsyncUpdate:
148     case HloOpcode::kAsyncDone:
149     case HloOpcode::kBatchNormGrad:
150     case HloOpcode::kBatchNormInference:
151     case HloOpcode::kBatchNormTraining:
152     case HloOpcode::kCall:
153     case HloOpcode::kCholesky:
154     case HloOpcode::kConditional:
155     case HloOpcode::kConvolution:
156     case HloOpcode::kAllGather:
157     case HloOpcode::kAllGatherStart:
158     case HloOpcode::kAllGatherDone:
159     case HloOpcode::kAllReduce:
160     case HloOpcode::kReduceScatter:
161     case HloOpcode::kAllReduceStart:
162     case HloOpcode::kAllReduceDone:
163     case HloOpcode::kAllToAll:
164     case HloOpcode::kCollectivePermute:
165     case HloOpcode::kCollectivePermuteDone:
166     case HloOpcode::kCollectivePermuteStart:
167     case HloOpcode::kCustomCall:
168     case HloOpcode::kDomain:
169     case HloOpcode::kDot:
170     case HloOpcode::kExp:
171     case HloOpcode::kExpm1:
172     case HloOpcode::kFft:
173     case HloOpcode::kFusion:
174     case HloOpcode::kGather:
175     case HloOpcode::kLog:
176     case HloOpcode::kLog1p:
177     case HloOpcode::kLogistic:
178     case HloOpcode::kMap:
179     case HloOpcode::kParameter:
180     case HloOpcode::kPower:
181     case HloOpcode::kRecv:
182     case HloOpcode::kRecvDone:
183     case HloOpcode::kReduce:
184     case HloOpcode::kReduceWindow:
185     case HloOpcode::kRng:
186     case HloOpcode::kRngGetAndUpdateState:
187     case HloOpcode::kRngBitGenerator:
188     case HloOpcode::kRsqrt:
189     case HloOpcode::kScatter:
190     case HloOpcode::kSelectAndScatter:
191     case HloOpcode::kSend:
192     case HloOpcode::kSendDone:
193     case HloOpcode::kSort:
194     case HloOpcode::kSqrt:
195     case HloOpcode::kCbrt:
196     case HloOpcode::kTanh:
197     case HloOpcode::kTriangularSolve:
198     case HloOpcode::kWhile:
199     case HloOpcode::kGetDimensionSize:
200     case HloOpcode::kSetDimensionSize:
201       return true;
202   }
203 
204   return false;
205 }
206 
207 // An "effectively at most unary" operation is one that has at most one "large"
208 // input with the others being negligible in terms of memory usage.
209 // We use "has a smaller true rank than the output" as a heuristic
210 // for "negligible" memory usage.
EffectivelyAtMostUnary(HloInstruction * hlo)211 bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) {
212   int64_t output_rank = 0;
213   ShapeUtil::ForEachSubshape(
214       hlo->shape(),
215       [&output_rank](const Shape& subshape, const ShapeIndex& shape_index) {
216         if (subshape.IsArray()) {
217           output_rank = std::max(output_rank, ShapeUtil::TrueRank(subshape));
218         }
219       });
220   return absl::c_count_if(
221              hlo->operands(), [output_rank](HloInstruction* operand) {
222                if (operand->opcode() == HloOpcode::kBroadcast ||
223                    operand->opcode() == HloOpcode::kIota) {
224                  return false;
225                }
226                if (operand->opcode() == HloOpcode::kConstant &&
227                    ShapeUtil::IsEffectiveScalar(operand->shape())) {
228                  return false;
229                }
230                return ShapeUtil::TrueRank(operand->shape()) >= output_rank;
231              }) <= 1;
232 }
233 
CanFuseOnAllPaths(HloInstruction * producer,HloInstruction * consumer,const HloInstructionSet & do_not_fuse,const HloReachabilityMap & reachability,absl::flat_hash_map<std::pair<HloInstruction *,HloInstruction * >,bool> * result_cache)234 bool InstructionFusion::CanFuseOnAllPaths(
235     HloInstruction* producer, HloInstruction* consumer,
236     const HloInstructionSet& do_not_fuse,
237     const HloReachabilityMap& reachability,
238     absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>*
239         result_cache) {
240   if (consumer == producer) {
241     return true;
242   }
243   if (!consumer->IsFusible()) {
244     return false;
245   }
246   auto cache_it = result_cache->find(std::make_pair(producer, consumer));
247   if (cache_it != result_cache->end()) {
248     return cache_it->second;
249   }
250   bool result = true;
251   for (int64_t i = 0, e = consumer->operand_count(); i < e; ++i) {
252     auto* consumer_operand = consumer->mutable_operand(i);
253     // If the operand is not on a path to the producer, it doesn't matter
254     // whether it's fusible.
255     if (!reachability.IsReachable(producer, consumer_operand)) {
256       continue;
257     }
258     if (do_not_fuse.count(consumer_operand) > 0 || !ShouldFuse(consumer, i)) {
259       result = false;
260       break;
261     }
262     // The producer is reachable from consumer_operand which means we need
263     // to be able to fuse consumer_operand into consumer in order for
264     // producer to be fusible into consumer on all paths.
265     // Perform the recursive step: make sure producer can be fused into
266     // consumer_operand on all paths.
267     if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_fuse,
268                            reachability, result_cache)) {
269       result = false;
270       break;
271     }
272   }
273   result_cache->emplace(std::make_pair(producer, consumer), result);
274   return result;
275 }
276 
277 InstructionFusion::HloInstructionSet
ComputeGloballyUnfusible(absl::Span<HloInstruction * const> post_order,const HloReachabilityMap & reachability)278 InstructionFusion::ComputeGloballyUnfusible(
279     absl::Span<HloInstruction* const> post_order,
280     const HloReachabilityMap& reachability) {
281   // Forbid fusion of producers that:
282   // a) Need to be duplicated, unless they can be fused into all consumers
283   //    via all paths.
284   // b) Are more than unary, that is, fusing them would likely lead to an
285   //    increase in memory bandwidth use.
286   //
287   // Note that if we allow fusion by these global rules, we may still forbid
288   // fusing operations that require duplication later depending on
289   // is_expensive_().
290   HloInstructionSet do_not_duplicate;
291   absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>
292       can_fuse_on_all_paths_result_cache;
293   for (auto it = post_order.rbegin(); it != post_order.rend(); ++it) {
294     HloInstruction* producer = *it;
295     // If the producer is effectively not more than unary, duplicating it
296     // will not increase the number of relevant inputs read, as the fusion
297     // node will only need to read at most 1 relevant input (the input of
298     // the producer). In that case, we do not forbid fusion of the operation
299     // here.
300     if (EffectivelyAtMostUnary(producer)) {
301       continue;
302     }
303 
304     // If the total size of the inputs is less than or equal to the total size
305     // of the outputs for the producer then duplicating it won't increase the
306     // memory traffic. In that case, we do not forbid fusion of the operation
307     // here.
308     auto total_size = [](const Shape& shape) {
309       int64_t size = 0;
310       ShapeUtil::ForEachSubshape(
311           shape, [&size](const Shape& subshape, const ShapeIndex& shape_index) {
312             if (subshape.IsArray()) {
313               size += ShapeUtil::ElementsIn(subshape);
314             }
315           });
316       return size;
317     };
318     int64_t operands_size = 0;
319     for (const HloInstruction* op : producer->unique_operands()) {
320       operands_size += total_size(op->shape());
321     }
322     if (operands_size <= total_size(producer->shape())) {
323       continue;
324     }
325 
326     // Otherwise we will forbid fusing the op unless we can fuse it into
327     // all of its consumers on all paths.
328     //
329     // That means, that for:
330     // A --> B (fusible)
331     //   \-> C (non-fusible)
332     // A will be not allowed to be fused into B, as it cannot be fused into C.
333     //
334     // Similarly, for:
335     // A -------------> B
336     //   \-> C -> D -/
337     // If:
338     // - A is fusible into B and C, and D is fusible into B
339     // - C is *not* fusible into D
340     // A will be not allowed to be fused into B, as it cannot be fused via
341     // all paths.
342     if (producer->IsFusible() &&
343         absl::c_all_of(producer->users(), [&](HloInstruction* consumer) {
344           return CanFuseOnAllPaths(producer, consumer, do_not_duplicate,
345                                    reachability,
346                                    &can_fuse_on_all_paths_result_cache);
347         })) {
348       continue;
349     }
350     do_not_duplicate.insert(producer);
351   }
352 
353   return do_not_duplicate;
354 }
355 
356 namespace {
357 
358 // A FusionQueue that uses reverse post order.
359 //
360 // We want to be able to remove arbitrary instructions from the post order and
361 // also compare positions of instructions in the post order. To make this
362 // possible, create vector of instructions in post order and create a map from
363 // HloInstruction* to the instruction's index in the vector. An instruction is
364 // "removed" from the vector by setting it's element to nullptr.
365 class ReversePostOrderFusionQueue : public FusionQueue {
366  public:
ReversePostOrderFusionQueue(HloComputation * computation)367   explicit ReversePostOrderFusionQueue(HloComputation* computation) {
368     post_order_ = computation->MakeInstructionPostOrder();
369 
370     for (size_t i = 0; i < post_order_.size(); ++i) {
371       InsertOrDie(&post_order_index_, post_order_[i], i);
372     }
373   }
374 
375   std::pair<HloInstruction*, std::vector<int64_t>>
DequeueNextInstructionAndOperandsToFuseInOrder()376   DequeueNextInstructionAndOperandsToFuseInOrder() override {
377     // Instructions are "removed" from the post order by nulling out the element
378     // in the vector, so if the pointer is null, continue to the next
379     // instruction in the sort.
380     while (!post_order_.empty() && post_order_.back() == nullptr) {
381       post_order_.pop_back();
382     }
383     if (post_order_.empty()) {
384       return std::pair<HloInstruction*, std::vector<int64_t>>{nullptr, {}};
385     }
386     // We want to iterate in reverse post order, so remove from the back of the
387     // vector.
388     HloInstruction* instruction = post_order_.back();
389     post_order_.pop_back();
390 
391     CHECK(instruction != nullptr);
392     // Remove instruction from the index map to ensure the vector and map stay
393     // consistent.
394     post_order_index_.erase(instruction);
395 
396     // Consider each operand of this instruction for fusion into this
397     // instruction. We want to consider the operands in a particular order to
398     // avoid creating duplicate instruction clones in the fusion instruction.
399     // For example, consider the following expression:
400     //
401     //   A = ...
402     //   B = op(A)
403     //   C = op(A, B)
404     //
405     // If we are considering the operands of C for fusion into C. We might
406     // fuse A or B first. If we fuse A first, we get:
407     //
408     //   A = ...
409     //   B = op(A)
410     //   C_fusion = { A' = ...
411     //                C' = op(A', B) }
412     //
413     // Where A' and C' are clones of A and C, respectively. Now only B is an
414     // operand of the fusion instruction C_fusion, so then we fuse B:
415     //
416     //   A = ...
417     //   B = op(A)
418     //   C_fusion = { A' = ...
419     //                B' = op(A)
420     //                C' = op(A', B') }
421     //
422     // Now A is an operand of C_fusion again, so we then fuse A (again!):
423     //
424     //   A = ...
425     //   B = op(A)
426     //   C_fusion = { A' = ...
427     //                A" = ..
428     //                B' = op(A")
429     //                C' = op(A', B') }
430     //
431     // We prevent this duplication by considering the operands in the order
432     // they appear int the queue. In the example, this ensures that B will be
433     // considered before A.
434     //
435     // We store the original indices of the operands to pass to ShouldFuse.
436     std::vector<int64_t> sorted_operand_numbers;
437     sorted_operand_numbers.reserve(instruction->operands().size());
438     for (int i = 0; i < instruction->operands().size(); ++i) {
439       // This will happen if we have two possible instructions to fuse the
440       // same operand into; once the operand is fused into one instruction,
441       // the other instruction will get a new get-tuple-element as its
442       // operand, which is not in the queue.
443       // TODO(tjoerg): Look into fusing past these multi-output fuse points.
444       if (!ContainsKey(post_order_index_, instruction->mutable_operand(i))) {
445         continue;
446       }
447       sorted_operand_numbers.push_back(i);
448     }
449     absl::c_sort(sorted_operand_numbers, [&](int64_t i, int64_t j) {
450       // Instructions with higher priority in the queue come first.
451       return (FindOrDie(post_order_index_, instruction->mutable_operand(i)) >
452               FindOrDie(post_order_index_, instruction->mutable_operand(j)));
453     });
454     return std::make_pair(instruction, sorted_operand_numbers);
455   }
456 
OnFusingInstruction(HloInstruction * fusion,HloInstruction * original_producer,HloInstruction * original_consumer)457   void OnFusingInstruction(HloInstruction* fusion,
458                            HloInstruction* original_producer,
459                            HloInstruction* original_consumer) override {
460     // Fusing an instruction into a fusion instruction can change the operand
461     // set of the fusion instruction. For simplicity just re-enqueue the
462     // instruction and reconsider it for further fusion in the next iteration.
463     InsertOrDie(&post_order_index_, fusion, post_order_.size());
464     post_order_.push_back(fusion);
465   }
466 
RemoveInstruction(HloInstruction * instruction)467   void RemoveInstruction(HloInstruction* instruction) override {
468     post_order_[FindOrDie(post_order_index_, instruction)] = nullptr;
469     post_order_index_.erase(instruction);
470   }
471 
FusionConfiguration()472   const std::vector<bool>* FusionConfiguration() override {
473     return &fusion_config_;
474   }
475 
476  private:
477   std::vector<HloInstruction*> post_order_;
478   absl::flat_hash_map<HloInstruction*, int> post_order_index_;
479   std::vector<bool> fusion_config_;
480 };
481 
482 }  // namespace
483 
GetFusionComputations(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)484 std::vector<HloComputation*> InstructionFusion::GetFusionComputations(
485     HloModule* module,
486     const absl::flat_hash_set<absl::string_view>& execution_threads) {
487   // Use sorted computations because fusion configuration is order-sensitive.
488   return module->MakeNonfusionComputationsSorted(execution_threads);
489 }
490 
GetFusionQueue(HloComputation * computation)491 std::unique_ptr<FusionQueue> InstructionFusion::GetFusionQueue(
492     HloComputation* computation) {
493   return std::make_unique<ReversePostOrderFusionQueue>(computation);
494 }
495 
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)496 StatusOr<bool> InstructionFusion::Run(
497     HloModule* module,
498     const absl::flat_hash_set<absl::string_view>& execution_threads) {
499   bool changed = false;
500   int64_t fuse_count = 0;
501   std::vector<std::vector<bool>>* fusion_config = nullptr;
502   HloModuleConfig module_config;
503   if (config_collection_mode_ != FusionConfigCollection::kOff) {
504     module_config = module->config();
505     fusion_config = module_config.mutable_fusion_config();
506     fusion_config->clear();
507   }
508 
509   bool dump_fusion =
510       module->config().debug_options().xla_dump_fusion_visualization();
511 
512   for (auto* computation : GetFusionComputations(module, execution_threads)) {
513     CHECK(!computation->IsFusionComputation());
514     std::unique_ptr<HloReachabilityMap> reachability =
515         HloReachabilityMap::Build(computation);
516 
517     HloInstructionSet do_not_duplicate;
518     // If we allow duplications, we need to compute which instructions we do not
519     // want to duplicate based on a global analysis of the graph.
520     if (may_duplicate_) {
521       do_not_duplicate = ComputeGloballyUnfusible(
522           computation->MakeInstructionPostOrder(), *reachability);
523     }
524     auto fusion_queue = GetFusionQueue(computation);
525 
526     // Instruction fusion effectively fuses edges in the computation graph
527     // (producer instruction -> consumer instruction) so we iterate over all
528     // edges. When we fuse an edge, we create a copy of the producer inside the
529     // fusion instruction.
530     while (true) {
531       std::pair<HloInstruction*, std::vector<int64_t>> next_entry =
532           fusion_queue->DequeueNextInstructionAndOperandsToFuseInOrder();
533       HloInstruction* instruction = next_entry.first;
534       if (instruction == nullptr) {
535         break;
536       }
537 
538       if (!instruction->IsFusible() &&
539           instruction->opcode() != HloOpcode::kFusion) {
540         continue;
541       }
542 
543       std::vector<int64_t>& sorted_operand_numbers = next_entry.second;
544 
545       for (int64_t i : sorted_operand_numbers) {
546         HloInstruction* operand = instruction->mutable_operand(i);
547         VLOG(5) << "Considering fusion of: " << instruction->ToString()
548                 << " with operand " << operand->name();
549 
550         if (!operand->IsFusible()) {
551           VLOG(3) << "Operand (" << operand->ToString() << ") is not fusible";
552           continue;
553         }
554 
555         // Consumes a unit of compiler fuel and returns true if we should
556         // continue with the transformation.
557         auto consume_fuel = [&] {
558           return ConsumeFuel(name(), /*ran_out_of_fuel_msg=*/[&] {
559             return absl::StrFormat("Not fusing operand %d of %s, namely, %s", i,
560                                    instruction->ToString(),
561                                    operand->ToString());
562           });
563         };
564 
565         HloInstruction* fusion_instruction = nullptr;
566 
567         FusionDecision should_fuse(do_not_duplicate.count(operand) == 0,
568                                    "operand can not be duplicated");
569 
570         // Try "regular" fusion if the operand may be duplicated. Otherwise,
571         // perform multi-output fusion, unless this creates a cycle.
572         if (should_fuse) {
573           should_fuse = ShouldFuse(instruction, i);
574           if (should_fuse && consume_fuel()) {
575             if (dump_fusion) {
576               RegisterFusionState(
577                   *computation,
578                   absl::StrCat("About to fuse |", operand->name(), "| into |",
579                                instruction->name(),
580                                "| inside InstructionFusion with may_duplicate=",
581                                may_duplicate_),
582                   /*consumer=*/*instruction,
583                   /*producer=*/operand);
584             }
585 
586             fusion_queue->PreFusion(operand, instruction);
587             fusion_instruction = Fuse(operand, instruction, computation);
588           }
589         }
590 
591         if (!should_fuse) {
592           FusionDecision can_fuse_mof =
593               ShouldFuseIntoMultiOutput(instruction, i);
594           if (can_fuse_mof) {
595             can_fuse_mof = can_fuse_mof.And(
596                 FusionDecision{!MultiOutputFusionCreatesCycle(
597                                    operand, instruction, *reachability),
598                                "multi-output fusion creates a cycle"});
599           }
600           if (can_fuse_mof) {
601             if (consume_fuel()) {
602               if (dump_fusion) {
603                 RegisterFusionState(
604                     *computation,
605                     absl::StrCat(
606                         "About to MOF-fuse |", operand->name(), "| into |",
607                         instruction->name(),
608                         "| inside InstructionFusion with may_duplicate=",
609                         may_duplicate_),
610                     /*consumer=*/*instruction, /*producer=*/operand);
611               }
612 
613               fusion_queue->PreFusion(operand, instruction);
614               fusion_instruction =
615                   FuseIntoMultiOutput(operand, instruction, computation);
616             }
617           }
618           should_fuse = should_fuse.Or(can_fuse_mof);
619         }
620 
621         if (fusion_instruction == nullptr) {
622           CHECK(!should_fuse.CanFuse());
623           if (dump_fusion) {
624             VLOG(2) << "Not fusing " << operand->ToShortString() << "| into |"
625                     << instruction->ToShortString() << "| as "
626                     << should_fuse.Explain();
627 
628             // Readability optimizations: lack of fusion for tuple accesses
629             // generates a lot of noise.
630             if (operand->opcode() != HloOpcode::kGetTupleElement &&
631                 instruction->opcode() != HloOpcode::kGetTupleElement) {
632               RegisterFusionState(*computation,
633                                   absl::StrCat("Not fusing |", operand->name(),
634                                                "| into |", instruction->name(),
635                                                "| as ", should_fuse.Explain()),
636                                   /*consumer=*/*instruction,
637                                   /*producer=*/operand);
638             }
639           }
640 
641           fusion_queue->NotFusingInstruction(operand, instruction);
642           continue;
643         }
644 
645         // Saving name to use after the instruction is removed.
646         std::string producer_name = operand->name();
647         fusion_queue->OnFusingInstruction(fusion_instruction, operand,
648                                           instruction);
649         changed = true;
650         ++fuse_count;
651 
652         if (operand->user_count() == 0) {
653           do_not_duplicate.erase(operand);
654           // Operand is now dead. Remove from queue.
655           fusion_queue->RemoveInstruction(operand);
656           // Remove from computation.
657           TF_RETURN_IF_ERROR(computation->RemoveInstruction(operand));
658         }
659 
660         if (dump_fusion) {
661           RegisterFusionState(
662               *computation,
663               absl::StrCat("Fused |", producer_name, "| into |",
664                            fusion_instruction->name(),
665                            "| inside InstructionFusion with may_duplicate=",
666                            may_duplicate_),
667               *fusion_instruction);
668         }
669 
670         if (fusion_instruction != instruction) {
671           do_not_duplicate.erase(instruction);
672         }
673         break;
674       }
675     }
676 
677     if (config_collection_mode_ != FusionConfigCollection::kOff) {
678       const std::vector<bool>* comp_fusion_config =
679           fusion_queue->FusionConfiguration();
680       if (comp_fusion_config && !comp_fusion_config->empty()) {
681         fusion_config->push_back(*comp_fusion_config);
682       }
683     }
684   }
685 
686   if (config_collection_mode_ != FusionConfigCollection::kOff) {
687     int64_t fused_count = 0;
688     for (auto& config_per_computation : *fusion_config) {
689       for (auto edge : config_per_computation) {
690         if (edge) {
691           ++fused_count;
692         }
693       }
694     }
695     VLOG(1) << "There are " << fused_count << " fused bits that cause "
696             << fuse_count << " fusion actions.";
697     module->set_config(module_config);
698   }
699 
700   VLOG(1) << "Fusion count: " << fuse_count;
701 
702   return changed;
703 }
704 
AddFusionInstruction(HloInstruction * producer,HloInstruction * consumer,HloComputation * computation)705 HloInstruction* InstructionFusion::AddFusionInstruction(
706     HloInstruction* producer, HloInstruction* consumer,
707     HloComputation* computation) {
708   HloInstruction* fusion_instruction;
709   auto kind = ChooseKind(producer, consumer);
710   if (consumer->opcode() == HloOpcode::kFusion) {
711     fusion_instruction = consumer;
712     if (kind != fusion_instruction->fusion_kind()) {
713       fusion_instruction->set_fusion_kind(kind);
714     }
715   } else {
716     fusion_instruction = computation->AddInstruction(
717         HloInstruction::CreateFusion(consumer->shape(), kind, consumer));
718     TF_CHECK_OK(computation->ReplaceInstruction(consumer, fusion_instruction));
719   }
720   fusion_instruction->set_called_computations_execution_thread(
721       computation->execution_thread(),
722       /*skip_async_execution_thread_overwrite=*/false);
723   return fusion_instruction;
724 }
725 
FuseInstruction(HloInstruction * fusion_instruction,HloInstruction * producer)726 HloInstruction* InstructionFusion::FuseInstruction(
727     HloInstruction* fusion_instruction, HloInstruction* producer) {
728   return fusion_instruction->FuseInstruction(producer);
729 }
730 
UpdateReusedOperandsForFusion(HloInstruction * producer,HloInstruction * fusion_instruction)731 void InstructionFusion::UpdateReusedOperandsForFusion(
732     HloInstruction* producer, HloInstruction* fusion_instruction) {
733   // Find or compute the existing fusion reused operands. Note these reflect the
734   // state *before* the current fusion has taken place, although if we have
735   // replaced the consumer with a new single-element fusion, we will compute
736   // the new single-element fusion's reused operands here.
737   absl::flat_hash_set<const HloInstruction*>& fusion_reused_operands =
738       ReusedOperandsOf(fusion_instruction);
739 
740   // If the producer is reused, replace it with its operands.
741   if (fusion_reused_operands.erase(producer)) {
742     fusion_reused_operands.insert(producer->operands().begin(),
743                                   producer->operands().end());
744   } else {
745     const absl::flat_hash_set<const HloInstruction*>& producer_reused_operands =
746         ReusedOperandsOf(producer);
747     // Otherwise add the producer's reused operands to the set.
748     fusion_reused_operands.insert(producer_reused_operands.begin(),
749                                   producer_reused_operands.end());
750   }
751 }
752 
Fuse(HloInstruction * producer,HloInstruction * consumer,HloComputation * computation)753 HloInstruction* InstructionFusion::Fuse(HloInstruction* producer,
754                                         HloInstruction* consumer,
755                                         HloComputation* computation) {
756   VLOG(2) << "Fusing " << producer->ToString() << " into "
757           << consumer->ToString();
758   HloInstruction* fusion_instruction =
759       AddFusionInstruction(producer, consumer, computation);
760   UpdateReusedOperandsForFusion(producer, fusion_instruction);
761   FuseInstruction(fusion_instruction, producer);
762   if (fusion_instruction != producer && fusion_instruction != consumer) {
763     VLOG(2) << "       created new fusion: " << fusion_instruction->ToString();
764   }
765   return fusion_instruction;
766 }
767 
FuseIntoMultiOutput(HloInstruction * producer,HloInstruction * consumer,HloComputation * computation)768 HloInstruction* InstructionFusion::FuseIntoMultiOutput(
769     HloInstruction* producer, HloInstruction* consumer,
770     HloComputation* computation) {
771   VLOG(2) << "Multi-output fusing " << producer->ToString() << " into "
772           << consumer->ToString();
773   HloInstruction* fusion_instruction =
774       AddFusionInstruction(producer, consumer, computation);
775   UpdateReusedOperandsForFusion(producer, fusion_instruction);
776   fusion_instruction->FuseInstructionIntoMultiOutput(producer);
777   return fusion_instruction;
778 }
779 
MultiOutputFusionCreatesCycle(HloInstruction * producer,HloInstruction * consumer,const HloReachabilityMap & reachability)780 bool InstructionFusion::MultiOutputFusionCreatesCycle(
781     HloInstruction* producer, HloInstruction* consumer,
782     const HloReachabilityMap& reachability) {
783   absl::flat_hash_set<int> operands;
784   for (const HloInstruction* operand : consumer->operands()) {
785     if (operand == producer) {
786       continue;
787     }
788 
789     // If the reachability map already contains the producer and the operand of
790     // the consumer, and the producer can reach the operand, then we know for
791     // sure MultiOutputFusion would create a cycle. If not, we need to do a DFS
792     // traversal of the computation to verify that this multioutput fusion would
793     // not create a cycle.
794     if (reachability.IsPresent(producer) && reachability.IsPresent(operand) &&
795         reachability.IsReachable(producer, operand)) {
796       return true;
797     }
798     operands.insert(operand->unique_id());
799   }
800 
801   // Do a DFS on the producer to see if any of the other consumer operands are
802   // reachable in the current state of the graph.
803   std::vector<HloInstruction*> worklist = producer->users();
804   absl::flat_hash_set<int> visits;
805   while (!worklist.empty()) {
806     const HloInstruction* user = worklist.back();
807     worklist.pop_back();
808     if (operands.count(user->unique_id()) != 0) {
809       return true;
810     }
811     if (visits.count(user->unique_id()) == 0) {
812       visits.insert(user->unique_id());
813       worklist.insert(worklist.end(), user->users().begin(),
814                       user->users().end());
815     }
816   }
817   return false;
818 }
819 
820 namespace {
821 
822 // Extracts instruction from the fusion that satisfies filter. If no or multiple
823 // instructions in the fusion satisfy filter, returns nullptr.
ExtractInstruction(const HloInstruction * hlo,const HloPredicate & filter)824 const HloInstruction* ExtractInstruction(const HloInstruction* hlo,
825                                          const HloPredicate& filter) {
826   if (filter(hlo)) {
827     return hlo;
828   }
829   if (hlo->opcode() != HloOpcode::kFusion) {
830     return nullptr;
831   }
832   const HloInstruction* match = nullptr;
833   for (HloInstruction* inst :
834        hlo->fused_instructions_computation()->instructions()) {
835     if (filter(inst)) {
836       if (match == nullptr) {
837         match = inst;
838       } else {
839         return nullptr;
840       }
841     }
842   }
843   return match;
844 }
845 
// Convenience overload: extracts the instruction whose opcode equals `opcode`,
// using the same rules as the predicate-based overload.
const HloInstruction* ExtractInstruction(const HloInstruction* hlo,
                                         HloOpcode opcode) {
  const auto has_opcode = [opcode](const HloInstruction* candidate) {
    return candidate->opcode() == opcode;
  };
  return ExtractInstruction(hlo, has_opcode);
}
852 
853 // Returns true if fusing a slice or dynamic slice in producer into a dynamic
854 // update slice fusion in consumer is safe. It is not safe to fuse the slice and
855 // DUS when fusing will cause another non-elementwise op to share operands with
856 // the DUS in-place buffer.
IsSafeToFuseSliceIntoDusFusion(const HloInstruction * producer,const HloInstruction * consumer)857 bool IsSafeToFuseSliceIntoDusFusion(const HloInstruction* producer,
858                                     const HloInstruction* consumer) {
859   CHECK_EQ(consumer->opcode(), HloOpcode::kFusion);
860   const HloInstruction* dus =
861       ExtractInstruction(consumer, HloOpcode::kDynamicUpdateSlice);
862   CHECK_NE(dus, nullptr);
863 
864   // Use a memoization map to avoid exponential runtime.
865   absl::flat_hash_map<const HloInstruction*, bool> nonelementwise_memo;
866   // Recursively check if the instruction or its users (or their users) are
867   // non-elementwise with the exception of the DUS. We have already verified
868   // that the slice and DUS are compatible since their indices match.
869   HloPredicate has_nonelementwise_uses_except_dus =
870       [&](const HloInstruction* instruction) {
871         auto record_and_return = [&](bool val) {
872           nonelementwise_memo[instruction] = val;
873           return val;
874         };
875         auto nonelementwise_memo_it = nonelementwise_memo.find(instruction);
876         if (nonelementwise_memo_it != nonelementwise_memo.end()) {
877           return nonelementwise_memo_it->second;
878         }
879         if (instruction != dus && !instruction->IsElementwise() &&
880             instruction->opcode() != HloOpcode::kParameter) {
881           return record_and_return(true);
882         }
883         return record_and_return(absl::c_any_of(
884             instruction->users(), has_nonelementwise_uses_except_dus));
885       };
886   for (int i = 0; i < consumer->operand_count(); ++i) {
887     if (consumer->operand(i) == producer &&
888         has_nonelementwise_uses_except_dus(consumer->fused_parameter(i))) {
889       VLOG(4) << "Found a different elementwise";
890       return false;
891     }
892   }
893   return true;
894 }
895 
896 }  // namespace
897 
ShouldFuseInPlaceOp(const HloInstruction * producer,const HloInstruction * consumer)898 /*static*/ FusionDecision InstructionFusion::ShouldFuseInPlaceOp(
899     const HloInstruction* producer, const HloInstruction* consumer) {
900   // Don't fuse if the producer is a non-elementwise op that has the same
901   // operand as an in-place operand of the consumer. The consumer will modify
902   // the buffer in-place, which will cause producer's operand to change if we
903   // allow them to fuse.
904   std::vector<std::pair<HloOperandIndex, ShapeIndex>>
905       in_place_input_output_pairs =
906           HloDataflowAnalysis::GetInPlaceInputOutputPairs(
907               const_cast<HloInstruction*>(consumer));
908   for (auto& pair : in_place_input_output_pairs) {
909     int operand_number = pair.first.operand_number;
910     VLOG(4) << "in/out pair: " << operand_number << " "
911             << pair.first.operand_index.ToString() << " "
912             << pair.second.ToString();
913     // Check if the consumer also has an additional operand that has the same
914     // value as the in-place buffer. If so, it's unsafe to fuse.
915     for (int i = 0; i < consumer->operand_count(); ++i) {
916       if (i != operand_number &&
917           consumer->operand(operand_number) == consumer->operand(i)) {
918         return "The consumer is an in-place operation that has an additional "
919                "operand that has the same value as the in-place buffer";
920       }
921     }
922     if (!producer->IsElementwise() &&
923         absl::c_find(producer->operands(), consumer->operand(operand_number)) !=
924             producer->operands().end()) {
925       VLOG(4) << "Found non-elementwise operand that uses the same operand of "
926                  "an in-place consumer";
927       auto get_real_operand = [](const HloInstruction* op,
928                                  const HloInstruction* operand) {
929         if (op->opcode() == HloOpcode::kFusion &&
930             operand->opcode() == HloOpcode::kParameter) {
931           return op->operand(operand->parameter_number());
932         }
933         return operand;
934       };
935 
936       auto get_constant_operand =
937           [](const HloInstruction* operand) -> std::optional<int> {
938         if (operand->IsConstant()) {
939           return operand->literal().GetFirstInteger();
940         }
941         return std::nullopt;
942       };
943       // A common special case is a slice or dynamic-slice and a
944       // dynamic-update-slice that use the same indices. This pattern is safe.
945       const HloInstruction* dus =
946           ExtractInstruction(consumer, HloOpcode::kDynamicUpdateSlice);
947       const HloInstruction* producer_nonelementwise =
948           ExtractInstruction(producer, [](const HloInstruction* inst) {
949             return inst->opcode() != HloOpcode::kFusion &&
950                    !inst->IsElementwise();
951           });
952       if (dus == nullptr || producer_nonelementwise == nullptr ||
953           producer_nonelementwise->shape() != dus->operand(1)->shape()) {
954         return "Consumer is not a dus or the producer fusion has multiple "
955                "non-elementwise ops, bailing.";
956       }
957       if (producer_nonelementwise->opcode() == HloOpcode::kSlice) {
958         for (int i = 0; i < dus->shape().rank(); ++i) {
959           const HloInstruction* dus_operand =
960               get_real_operand(consumer, dus->operand(2 + i));
961           auto constant_operand = get_constant_operand(dus_operand);
962           if (!constant_operand ||
963               *constant_operand != producer_nonelementwise->slice_starts(i) ||
964               producer_nonelementwise->slice_strides(i) != 1) {
965             return "DUS and slice index mismatch";
966           }
967         }
968         VLOG(4) << "DUS and slice index match";
969         if (consumer->opcode() == HloOpcode::kFusion &&
970             !IsSafeToFuseSliceIntoDusFusion(producer, consumer)) {
971           return "Fusing slice into DUS will also fuse another non-elementwise "
972                  "op with shared operand as DUS.";
973         }
974         return {};
975       }
976       if (producer_nonelementwise->opcode() == HloOpcode::kDynamicSlice) {
977         for (int i = 0; i < dus->shape().rank(); ++i) {
978           const HloInstruction* ds_operand = get_real_operand(
979               producer, producer_nonelementwise->operand(1 + i));
980           const HloInstruction* dus_operand =
981               get_real_operand(consumer, dus->operand(2 + i));
982           auto constant_ds_operand = get_constant_operand(ds_operand);
983           auto constant_dus_operand = get_constant_operand(dus_operand);
984           if (constant_ds_operand != constant_dus_operand ||
985               (!constant_ds_operand && ds_operand != dus_operand)) {
986             return "DUS and DS index mismatch";
987           }
988         }
989         VLOG(4) << "DUS and DS index match";
990         if (consumer->opcode() == HloOpcode::kFusion &&
991             !IsSafeToFuseSliceIntoDusFusion(producer, consumer)) {
992           return "Fusing DS into DUS will also fuse another non-elementwise op "
993                  "with shared operand as DUS.";
994         }
995         return {};
996       }
997       return "unrecognized inplace update non-elementwise output pair";
998     }
999   }
1000   return {};
1001 }
1002 
ShouldFuse(HloInstruction * consumer,int64_t operand_index)1003 FusionDecision InstructionFusion::ShouldFuse(HloInstruction* consumer,
1004                                              int64_t operand_index) {
1005   HloInstruction* producer = consumer->mutable_operand(operand_index);
1006 
1007   // Don't fuse across a root instruction.
1008   if (producer == producer->parent()->root_instruction()) {
1009     return "not fusing into the output of the root instruction";
1010   }
1011 
1012   // Cost condition: don't duplicate expensive instructions.
1013   if (FusionWouldDuplicate(*producer, *consumer) &&
1014       (!may_duplicate_ || is_expensive_(*producer)) &&
1015       !IsAlwaysDuplicable(*producer)) {
1016     return may_duplicate_ ? "expensive producer would be duplicated"
1017                           : "fusion pass cannot duplicate";
1018   }
1019 
1020   if (NoFusionPossible fusible = !ShouldFuseInPlaceOp(producer, consumer)) {
1021     return !fusible;
1022   }
1023 
1024   return {};
1025 }
1026 
// Picks the fusion kind used when fusing `producer` into `consumer`. This
// implementation ignores both arguments and always selects kLoop.
HloInstruction::FusionKind InstructionFusion::ChooseKind(
    const HloInstruction* producer, const HloInstruction* consumer) {
  constexpr HloInstruction::FusionKind kDefaultKind =
      HloInstruction::FusionKind::kLoop;
  return kDefaultKind;
}
1031 
ReusedOperandsOf(const HloInstruction * instruction)1032 absl::flat_hash_set<const HloInstruction*>& InstructionFusion::ReusedOperandsOf(
1033     const HloInstruction* instruction) {
1034   std::unique_ptr<absl::flat_hash_set<const HloInstruction*>>& reused_operands =
1035       reused_fusion_operands_[instruction];
1036   if (reused_operands != nullptr) {
1037     return *reused_operands;
1038   }
1039   reused_operands =
1040       std::make_unique<absl::flat_hash_set<const HloInstruction*>>();
1041 
1042   for (int64_t i = 0; i < instruction->operand_count(); ++i) {
1043     bool reuses = instruction->ReusesOperandElements(i);
1044     if (reuses) {
1045       // We cache the operand corresponding to the fusion parameter, because the
1046       // parameter numbers would be invalidated after the next fusion.
1047       reused_operands->insert(instruction->operand(i));
1048     }
1049   }
1050   return *reused_operands;
1051 }
1052 
ReusesOperandElements(const HloInstruction * consumer,int64_t operand_index)1053 bool InstructionFusion::ReusesOperandElements(const HloInstruction* consumer,
1054                                               int64_t operand_index) {
1055   auto operand = consumer->operand(operand_index);
1056   return ReusedOperandsOf(consumer).contains(operand);
1057 }
1058 
1059 }  // namespace xla
1060