1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/instruction_fusion.h"
17
18 #include <algorithm>
19 #include <functional>
20 #include <list>
21 #include <memory>
22 #include <numeric>
23 #include <string>
24 #include <utility>
25 #include <vector>
26
27 #include "absl/algorithm/container.h"
28 #include "absl/container/flat_hash_map.h"
29 #include "absl/container/flat_hash_set.h"
30 #include "tensorflow/compiler/xla/debug_options_flags.h"
31 #include "tensorflow/compiler/xla/map_util.h"
32 #include "tensorflow/compiler/xla/service/fusion_queue.h"
33 #include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
34 #include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
35 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
36 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
37 #include "tensorflow/compiler/xla/service/hlo_reachability.h"
38 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/platform/logging.h"
41
42 namespace xla {
43 namespace {
44
45 // These nodes can always be duplicated into consumers, even if
46 // InstructionFusion::may_duplicate_ is false.
47 //
48 // In general these should be nodes that get *cheaper* the more they're
49 // duplicated (and fused into consumers).
50 //
51 // TODO(jlebar): Duplicating instructions when we have a variable called "may
52 // duplicate" that's equal to false is not pretty.
IsAlwaysDuplicable(const HloInstruction & instruction)53 bool IsAlwaysDuplicable(const HloInstruction& instruction) {
54 // We are always willing to duplicate a widening type-conversion instruction
55 // if it means we can fuse the convert into a consumer. This allows the
56 // consumer to read less memory, which is almost always a performance win.
57 return instruction.opcode() == HloOpcode::kConvert &&
58 ShapeUtil::ByteSizeOf(instruction.operand(0)->shape()) <
59 ShapeUtil::ByteSizeOf(instruction.shape());
60 }
61 } // namespace
62
IsExpensive(const HloInstruction & instruction)63 /*static*/ bool InstructionFusion::IsExpensive(
64 const HloInstruction& instruction) {
65 namespace m = match;
66
67 switch (instruction.opcode()) {
68 // Cheap instructions.
69 case HloOpcode::kAdd:
70 case HloOpcode::kAnd:
71 case HloOpcode::kBitcast:
72 case HloOpcode::kBitcastConvert:
73 case HloOpcode::kBroadcast:
74 case HloOpcode::kCeil:
75 case HloOpcode::kClamp:
76 case HloOpcode::kClz:
77 case HloOpcode::kCompare:
78 case HloOpcode::kComplex:
79 case HloOpcode::kConcatenate:
80 case HloOpcode::kConstant:
81 case HloOpcode::kConvert:
82 case HloOpcode::kCopy:
83 case HloOpcode::kCopyDone:
84 case HloOpcode::kCopyStart:
85 case HloOpcode::kDynamicSlice:
86 case HloOpcode::kDynamicUpdateSlice:
87 case HloOpcode::kFloor:
88 case HloOpcode::kGetTupleElement:
89 case HloOpcode::kImag:
90 case HloOpcode::kInfeed:
91 case HloOpcode::kIota:
92 case HloOpcode::kIsFinite:
93 case HloOpcode::kMaximum:
94 case HloOpcode::kMinimum:
95 case HloOpcode::kMultiply:
96 case HloOpcode::kNegate:
97 case HloOpcode::kNot:
98 case HloOpcode::kOptimizationBarrier:
99 case HloOpcode::kOr:
100 case HloOpcode::kXor:
101 case HloOpcode::kOutfeed:
102 case HloOpcode::kPad:
103 case HloOpcode::kPartitionId:
104 case HloOpcode::kPopulationCount:
105 case HloOpcode::kReal:
106 case HloOpcode::kReducePrecision:
107 case HloOpcode::kReplicaId:
108 case HloOpcode::kReshape:
109 case HloOpcode::kDynamicReshape:
110 case HloOpcode::kReverse:
111 case HloOpcode::kRoundNearestAfz:
112 case HloOpcode::kRoundNearestEven:
113 case HloOpcode::kSelect:
114 case HloOpcode::kShiftLeft:
115 case HloOpcode::kShiftRightArithmetic:
116 case HloOpcode::kShiftRightLogical:
117 case HloOpcode::kSlice:
118 case HloOpcode::kSubtract:
119 case HloOpcode::kTranspose:
120 case HloOpcode::kTuple:
121 return false;
122
123 // Cheap instructions for reals, but expensive for complex.
124 case HloOpcode::kAbs:
125 case HloOpcode::kCos:
126 case HloOpcode::kSign:
127 case HloOpcode::kSin:
128 return ShapeUtil::ElementIsComplex(instruction.shape());
129
130 // We say that integer div/mod by a constant is cheap because it gets
131 // compiled down to multiplies and shifts, and we consider those to be
132 // cheap.
133 case HloOpcode::kDivide:
134 case HloOpcode::kRemainder:
135 return !ShapeUtil::ElementIsIntegral(instruction.shape()) ||
136 !Match(instruction.operand(1),
137 m::AnyOf<const HloInstruction>(
138 m::ConstantEffectiveScalar(),
139 m::Broadcast(m::ConstantEffectiveScalar())));
140
141 // Expensive instructions or unusual instructions for which fusion is
142 // nonsensical.
143 case HloOpcode::kAddDependency:
144 case HloOpcode::kAfterAll:
145 case HloOpcode::kAtan2:
146 case HloOpcode::kAsyncStart:
147 case HloOpcode::kAsyncUpdate:
148 case HloOpcode::kAsyncDone:
149 case HloOpcode::kBatchNormGrad:
150 case HloOpcode::kBatchNormInference:
151 case HloOpcode::kBatchNormTraining:
152 case HloOpcode::kCall:
153 case HloOpcode::kCholesky:
154 case HloOpcode::kConditional:
155 case HloOpcode::kConvolution:
156 case HloOpcode::kAllGather:
157 case HloOpcode::kAllGatherStart:
158 case HloOpcode::kAllGatherDone:
159 case HloOpcode::kAllReduce:
160 case HloOpcode::kReduceScatter:
161 case HloOpcode::kAllReduceStart:
162 case HloOpcode::kAllReduceDone:
163 case HloOpcode::kAllToAll:
164 case HloOpcode::kCollectivePermute:
165 case HloOpcode::kCollectivePermuteDone:
166 case HloOpcode::kCollectivePermuteStart:
167 case HloOpcode::kCustomCall:
168 case HloOpcode::kDomain:
169 case HloOpcode::kDot:
170 case HloOpcode::kExp:
171 case HloOpcode::kExpm1:
172 case HloOpcode::kFft:
173 case HloOpcode::kFusion:
174 case HloOpcode::kGather:
175 case HloOpcode::kLog:
176 case HloOpcode::kLog1p:
177 case HloOpcode::kLogistic:
178 case HloOpcode::kMap:
179 case HloOpcode::kParameter:
180 case HloOpcode::kPower:
181 case HloOpcode::kRecv:
182 case HloOpcode::kRecvDone:
183 case HloOpcode::kReduce:
184 case HloOpcode::kReduceWindow:
185 case HloOpcode::kRng:
186 case HloOpcode::kRngGetAndUpdateState:
187 case HloOpcode::kRngBitGenerator:
188 case HloOpcode::kRsqrt:
189 case HloOpcode::kScatter:
190 case HloOpcode::kSelectAndScatter:
191 case HloOpcode::kSend:
192 case HloOpcode::kSendDone:
193 case HloOpcode::kSort:
194 case HloOpcode::kSqrt:
195 case HloOpcode::kCbrt:
196 case HloOpcode::kTanh:
197 case HloOpcode::kTriangularSolve:
198 case HloOpcode::kWhile:
199 case HloOpcode::kGetDimensionSize:
200 case HloOpcode::kSetDimensionSize:
201 return true;
202 }
203
204 return false;
205 }
206
207 // An "effectively at most unary" operation is one that has at most one "large"
208 // input with the others being negligible in terms of memory usage.
209 // We use "has a smaller true rank than the output" as a heuristic
210 // for "negligible" memory usage.
EffectivelyAtMostUnary(HloInstruction * hlo)211 bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) {
212 int64_t output_rank = 0;
213 ShapeUtil::ForEachSubshape(
214 hlo->shape(),
215 [&output_rank](const Shape& subshape, const ShapeIndex& shape_index) {
216 if (subshape.IsArray()) {
217 output_rank = std::max(output_rank, ShapeUtil::TrueRank(subshape));
218 }
219 });
220 return absl::c_count_if(
221 hlo->operands(), [output_rank](HloInstruction* operand) {
222 if (operand->opcode() == HloOpcode::kBroadcast ||
223 operand->opcode() == HloOpcode::kIota) {
224 return false;
225 }
226 if (operand->opcode() == HloOpcode::kConstant &&
227 ShapeUtil::IsEffectiveScalar(operand->shape())) {
228 return false;
229 }
230 return ShapeUtil::TrueRank(operand->shape()) >= output_rank;
231 }) <= 1;
232 }
233
CanFuseOnAllPaths(HloInstruction * producer,HloInstruction * consumer,const HloInstructionSet & do_not_fuse,const HloReachabilityMap & reachability,absl::flat_hash_map<std::pair<HloInstruction *,HloInstruction * >,bool> * result_cache)234 bool InstructionFusion::CanFuseOnAllPaths(
235 HloInstruction* producer, HloInstruction* consumer,
236 const HloInstructionSet& do_not_fuse,
237 const HloReachabilityMap& reachability,
238 absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>*
239 result_cache) {
240 if (consumer == producer) {
241 return true;
242 }
243 if (!consumer->IsFusible()) {
244 return false;
245 }
246 auto cache_it = result_cache->find(std::make_pair(producer, consumer));
247 if (cache_it != result_cache->end()) {
248 return cache_it->second;
249 }
250 bool result = true;
251 for (int64_t i = 0, e = consumer->operand_count(); i < e; ++i) {
252 auto* consumer_operand = consumer->mutable_operand(i);
253 // If the operand is not on a path to the producer, it doesn't matter
254 // whether it's fusible.
255 if (!reachability.IsReachable(producer, consumer_operand)) {
256 continue;
257 }
258 if (do_not_fuse.count(consumer_operand) > 0 || !ShouldFuse(consumer, i)) {
259 result = false;
260 break;
261 }
262 // The producer is reachable from consumer_operand which means we need
263 // to be able to fuse consumer_operand into consumer in order for
264 // producer to be fusible into consumer on all paths.
265 // Perform the recursive step: make sure producer can be fused into
266 // consumer_operand on all paths.
267 if (!CanFuseOnAllPaths(producer, consumer_operand, do_not_fuse,
268 reachability, result_cache)) {
269 result = false;
270 break;
271 }
272 }
273 result_cache->emplace(std::make_pair(producer, consumer), result);
274 return result;
275 }
276
277 InstructionFusion::HloInstructionSet
ComputeGloballyUnfusible(absl::Span<HloInstruction * const> post_order,const HloReachabilityMap & reachability)278 InstructionFusion::ComputeGloballyUnfusible(
279 absl::Span<HloInstruction* const> post_order,
280 const HloReachabilityMap& reachability) {
281 // Forbid fusion of producers that:
282 // a) Need to be duplicated, unless they can be fused into all consumers
283 // via all paths.
284 // b) Are more than unary, that is, fusing them would likely lead to an
285 // increase in memory bandwidth use.
286 //
287 // Note that if we allow fusion by these global rules, we may still forbid
288 // fusing operations that require duplication later depending on
289 // is_expensive_().
290 HloInstructionSet do_not_duplicate;
291 absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>
292 can_fuse_on_all_paths_result_cache;
293 for (auto it = post_order.rbegin(); it != post_order.rend(); ++it) {
294 HloInstruction* producer = *it;
295 // If the producer is effectively not more than unary, duplicating it
296 // will not increase the number of relevant inputs read, as the fusion
297 // node will only need to read at most 1 relevant input (the input of
298 // the producer). In that case, we do not forbid fusion of the operation
299 // here.
300 if (EffectivelyAtMostUnary(producer)) {
301 continue;
302 }
303
304 // If the total size of the inputs is less than or equal to the total size
305 // of the outputs for the producer then duplicating it won't increase the
306 // memory traffic. In that case, we do not forbid fusion of the operation
307 // here.
308 auto total_size = [](const Shape& shape) {
309 int64_t size = 0;
310 ShapeUtil::ForEachSubshape(
311 shape, [&size](const Shape& subshape, const ShapeIndex& shape_index) {
312 if (subshape.IsArray()) {
313 size += ShapeUtil::ElementsIn(subshape);
314 }
315 });
316 return size;
317 };
318 int64_t operands_size = 0;
319 for (const HloInstruction* op : producer->unique_operands()) {
320 operands_size += total_size(op->shape());
321 }
322 if (operands_size <= total_size(producer->shape())) {
323 continue;
324 }
325
326 // Otherwise we will forbid fusing the op unless we can fuse it into
327 // all of its consumers on all paths.
328 //
329 // That means, that for:
330 // A --> B (fusible)
331 // \-> C (non-fusible)
332 // A will be not allowed to be fused into B, as it cannot be fused into C.
333 //
334 // Similarly, for:
335 // A -------------> B
336 // \-> C -> D -/
337 // If:
338 // - A is fusible into B and C, and D is fusible into B
339 // - C is *not* fusible into D
340 // A will be not allowed to be fused into B, as it cannot be fused via
341 // all paths.
342 if (producer->IsFusible() &&
343 absl::c_all_of(producer->users(), [&](HloInstruction* consumer) {
344 return CanFuseOnAllPaths(producer, consumer, do_not_duplicate,
345 reachability,
346 &can_fuse_on_all_paths_result_cache);
347 })) {
348 continue;
349 }
350 do_not_duplicate.insert(producer);
351 }
352
353 return do_not_duplicate;
354 }
355
356 namespace {
357
358 // A FusionQueue that uses reverse post order.
359 //
360 // We want to be able to remove arbitrary instructions from the post order and
361 // also compare positions of instructions in the post order. To make this
362 // possible, create vector of instructions in post order and create a map from
363 // HloInstruction* to the instruction's index in the vector. An instruction is
364 // "removed" from the vector by setting it's element to nullptr.
365 class ReversePostOrderFusionQueue : public FusionQueue {
366 public:
ReversePostOrderFusionQueue(HloComputation * computation)367 explicit ReversePostOrderFusionQueue(HloComputation* computation) {
368 post_order_ = computation->MakeInstructionPostOrder();
369
370 for (size_t i = 0; i < post_order_.size(); ++i) {
371 InsertOrDie(&post_order_index_, post_order_[i], i);
372 }
373 }
374
375 std::pair<HloInstruction*, std::vector<int64_t>>
DequeueNextInstructionAndOperandsToFuseInOrder()376 DequeueNextInstructionAndOperandsToFuseInOrder() override {
377 // Instructions are "removed" from the post order by nulling out the element
378 // in the vector, so if the pointer is null, continue to the next
379 // instruction in the sort.
380 while (!post_order_.empty() && post_order_.back() == nullptr) {
381 post_order_.pop_back();
382 }
383 if (post_order_.empty()) {
384 return std::pair<HloInstruction*, std::vector<int64_t>>{nullptr, {}};
385 }
386 // We want to iterate in reverse post order, so remove from the back of the
387 // vector.
388 HloInstruction* instruction = post_order_.back();
389 post_order_.pop_back();
390
391 CHECK(instruction != nullptr);
392 // Remove instruction from the index map to ensure the vector and map stay
393 // consistent.
394 post_order_index_.erase(instruction);
395
396 // Consider each operand of this instruction for fusion into this
397 // instruction. We want to consider the operands in a particular order to
398 // avoid creating duplicate instruction clones in the fusion instruction.
399 // For example, consider the following expression:
400 //
401 // A = ...
402 // B = op(A)
403 // C = op(A, B)
404 //
405 // If we are considering the operands of C for fusion into C. We might
406 // fuse A or B first. If we fuse A first, we get:
407 //
408 // A = ...
409 // B = op(A)
410 // C_fusion = { A' = ...
411 // C' = op(A', B) }
412 //
413 // Where A' and C' are clones of A and C, respectively. Now only B is an
414 // operand of the fusion instruction C_fusion, so then we fuse B:
415 //
416 // A = ...
417 // B = op(A)
418 // C_fusion = { A' = ...
419 // B' = op(A)
420 // C' = op(A', B') }
421 //
422 // Now A is an operand of C_fusion again, so we then fuse A (again!):
423 //
424 // A = ...
425 // B = op(A)
426 // C_fusion = { A' = ...
427 // A" = ..
428 // B' = op(A")
429 // C' = op(A', B') }
430 //
431 // We prevent this duplication by considering the operands in the order
432 // they appear int the queue. In the example, this ensures that B will be
433 // considered before A.
434 //
435 // We store the original indices of the operands to pass to ShouldFuse.
436 std::vector<int64_t> sorted_operand_numbers;
437 sorted_operand_numbers.reserve(instruction->operands().size());
438 for (int i = 0; i < instruction->operands().size(); ++i) {
439 // This will happen if we have two possible instructions to fuse the
440 // same operand into; once the operand is fused into one instruction,
441 // the other instruction will get a new get-tuple-element as its
442 // operand, which is not in the queue.
443 // TODO(tjoerg): Look into fusing past these multi-output fuse points.
444 if (!ContainsKey(post_order_index_, instruction->mutable_operand(i))) {
445 continue;
446 }
447 sorted_operand_numbers.push_back(i);
448 }
449 absl::c_sort(sorted_operand_numbers, [&](int64_t i, int64_t j) {
450 // Instructions with higher priority in the queue come first.
451 return (FindOrDie(post_order_index_, instruction->mutable_operand(i)) >
452 FindOrDie(post_order_index_, instruction->mutable_operand(j)));
453 });
454 return std::make_pair(instruction, sorted_operand_numbers);
455 }
456
OnFusingInstruction(HloInstruction * fusion,HloInstruction * original_producer,HloInstruction * original_consumer)457 void OnFusingInstruction(HloInstruction* fusion,
458 HloInstruction* original_producer,
459 HloInstruction* original_consumer) override {
460 // Fusing an instruction into a fusion instruction can change the operand
461 // set of the fusion instruction. For simplicity just re-enqueue the
462 // instruction and reconsider it for further fusion in the next iteration.
463 InsertOrDie(&post_order_index_, fusion, post_order_.size());
464 post_order_.push_back(fusion);
465 }
466
RemoveInstruction(HloInstruction * instruction)467 void RemoveInstruction(HloInstruction* instruction) override {
468 post_order_[FindOrDie(post_order_index_, instruction)] = nullptr;
469 post_order_index_.erase(instruction);
470 }
471
FusionConfiguration()472 const std::vector<bool>* FusionConfiguration() override {
473 return &fusion_config_;
474 }
475
476 private:
477 std::vector<HloInstruction*> post_order_;
478 absl::flat_hash_map<HloInstruction*, int> post_order_index_;
479 std::vector<bool> fusion_config_;
480 };
481
482 } // namespace
483
GetFusionComputations(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)484 std::vector<HloComputation*> InstructionFusion::GetFusionComputations(
485 HloModule* module,
486 const absl::flat_hash_set<absl::string_view>& execution_threads) {
487 // Use sorted computations because fusion configuration is order-sensitive.
488 return module->MakeNonfusionComputationsSorted(execution_threads);
489 }
490
GetFusionQueue(HloComputation * computation)491 std::unique_ptr<FusionQueue> InstructionFusion::GetFusionQueue(
492 HloComputation* computation) {
493 return std::make_unique<ReversePostOrderFusionQueue>(computation);
494 }
495
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)496 StatusOr<bool> InstructionFusion::Run(
497 HloModule* module,
498 const absl::flat_hash_set<absl::string_view>& execution_threads) {
499 bool changed = false;
500 int64_t fuse_count = 0;
501 std::vector<std::vector<bool>>* fusion_config = nullptr;
502 HloModuleConfig module_config;
503 if (config_collection_mode_ != FusionConfigCollection::kOff) {
504 module_config = module->config();
505 fusion_config = module_config.mutable_fusion_config();
506 fusion_config->clear();
507 }
508
509 bool dump_fusion =
510 module->config().debug_options().xla_dump_fusion_visualization();
511
512 for (auto* computation : GetFusionComputations(module, execution_threads)) {
513 CHECK(!computation->IsFusionComputation());
514 std::unique_ptr<HloReachabilityMap> reachability =
515 HloReachabilityMap::Build(computation);
516
517 HloInstructionSet do_not_duplicate;
518 // If we allow duplications, we need to compute which instructions we do not
519 // want to duplicate based on a global analysis of the graph.
520 if (may_duplicate_) {
521 do_not_duplicate = ComputeGloballyUnfusible(
522 computation->MakeInstructionPostOrder(), *reachability);
523 }
524 auto fusion_queue = GetFusionQueue(computation);
525
526 // Instruction fusion effectively fuses edges in the computation graph
527 // (producer instruction -> consumer instruction) so we iterate over all
528 // edges. When we fuse an edge, we create a copy of the producer inside the
529 // fusion instruction.
530 while (true) {
531 std::pair<HloInstruction*, std::vector<int64_t>> next_entry =
532 fusion_queue->DequeueNextInstructionAndOperandsToFuseInOrder();
533 HloInstruction* instruction = next_entry.first;
534 if (instruction == nullptr) {
535 break;
536 }
537
538 if (!instruction->IsFusible() &&
539 instruction->opcode() != HloOpcode::kFusion) {
540 continue;
541 }
542
543 std::vector<int64_t>& sorted_operand_numbers = next_entry.second;
544
545 for (int64_t i : sorted_operand_numbers) {
546 HloInstruction* operand = instruction->mutable_operand(i);
547 VLOG(5) << "Considering fusion of: " << instruction->ToString()
548 << " with operand " << operand->name();
549
550 if (!operand->IsFusible()) {
551 VLOG(3) << "Operand (" << operand->ToString() << ") is not fusible";
552 continue;
553 }
554
555 // Consumes a unit of compiler fuel and returns true if we should
556 // continue with the transformation.
557 auto consume_fuel = [&] {
558 return ConsumeFuel(name(), /*ran_out_of_fuel_msg=*/[&] {
559 return absl::StrFormat("Not fusing operand %d of %s, namely, %s", i,
560 instruction->ToString(),
561 operand->ToString());
562 });
563 };
564
565 HloInstruction* fusion_instruction = nullptr;
566
567 FusionDecision should_fuse(do_not_duplicate.count(operand) == 0,
568 "operand can not be duplicated");
569
570 // Try "regular" fusion if the operand may be duplicated. Otherwise,
571 // perform multi-output fusion, unless this creates a cycle.
572 if (should_fuse) {
573 should_fuse = ShouldFuse(instruction, i);
574 if (should_fuse && consume_fuel()) {
575 if (dump_fusion) {
576 RegisterFusionState(
577 *computation,
578 absl::StrCat("About to fuse |", operand->name(), "| into |",
579 instruction->name(),
580 "| inside InstructionFusion with may_duplicate=",
581 may_duplicate_),
582 /*consumer=*/*instruction,
583 /*producer=*/operand);
584 }
585
586 fusion_queue->PreFusion(operand, instruction);
587 fusion_instruction = Fuse(operand, instruction, computation);
588 }
589 }
590
591 if (!should_fuse) {
592 FusionDecision can_fuse_mof =
593 ShouldFuseIntoMultiOutput(instruction, i);
594 if (can_fuse_mof) {
595 can_fuse_mof = can_fuse_mof.And(
596 FusionDecision{!MultiOutputFusionCreatesCycle(
597 operand, instruction, *reachability),
598 "multi-output fusion creates a cycle"});
599 }
600 if (can_fuse_mof) {
601 if (consume_fuel()) {
602 if (dump_fusion) {
603 RegisterFusionState(
604 *computation,
605 absl::StrCat(
606 "About to MOF-fuse |", operand->name(), "| into |",
607 instruction->name(),
608 "| inside InstructionFusion with may_duplicate=",
609 may_duplicate_),
610 /*consumer=*/*instruction, /*producer=*/operand);
611 }
612
613 fusion_queue->PreFusion(operand, instruction);
614 fusion_instruction =
615 FuseIntoMultiOutput(operand, instruction, computation);
616 }
617 }
618 should_fuse = should_fuse.Or(can_fuse_mof);
619 }
620
621 if (fusion_instruction == nullptr) {
622 CHECK(!should_fuse.CanFuse());
623 if (dump_fusion) {
624 VLOG(2) << "Not fusing " << operand->ToShortString() << "| into |"
625 << instruction->ToShortString() << "| as "
626 << should_fuse.Explain();
627
628 // Readability optimizations: lack of fusion for tuple accesses
629 // generates a lot of noise.
630 if (operand->opcode() != HloOpcode::kGetTupleElement &&
631 instruction->opcode() != HloOpcode::kGetTupleElement) {
632 RegisterFusionState(*computation,
633 absl::StrCat("Not fusing |", operand->name(),
634 "| into |", instruction->name(),
635 "| as ", should_fuse.Explain()),
636 /*consumer=*/*instruction,
637 /*producer=*/operand);
638 }
639 }
640
641 fusion_queue->NotFusingInstruction(operand, instruction);
642 continue;
643 }
644
645 // Saving name to use after the instruction is removed.
646 std::string producer_name = operand->name();
647 fusion_queue->OnFusingInstruction(fusion_instruction, operand,
648 instruction);
649 changed = true;
650 ++fuse_count;
651
652 if (operand->user_count() == 0) {
653 do_not_duplicate.erase(operand);
654 // Operand is now dead. Remove from queue.
655 fusion_queue->RemoveInstruction(operand);
656 // Remove from computation.
657 TF_RETURN_IF_ERROR(computation->RemoveInstruction(operand));
658 }
659
660 if (dump_fusion) {
661 RegisterFusionState(
662 *computation,
663 absl::StrCat("Fused |", producer_name, "| into |",
664 fusion_instruction->name(),
665 "| inside InstructionFusion with may_duplicate=",
666 may_duplicate_),
667 *fusion_instruction);
668 }
669
670 if (fusion_instruction != instruction) {
671 do_not_duplicate.erase(instruction);
672 }
673 break;
674 }
675 }
676
677 if (config_collection_mode_ != FusionConfigCollection::kOff) {
678 const std::vector<bool>* comp_fusion_config =
679 fusion_queue->FusionConfiguration();
680 if (comp_fusion_config && !comp_fusion_config->empty()) {
681 fusion_config->push_back(*comp_fusion_config);
682 }
683 }
684 }
685
686 if (config_collection_mode_ != FusionConfigCollection::kOff) {
687 int64_t fused_count = 0;
688 for (auto& config_per_computation : *fusion_config) {
689 for (auto edge : config_per_computation) {
690 if (edge) {
691 ++fused_count;
692 }
693 }
694 }
695 VLOG(1) << "There are " << fused_count << " fused bits that cause "
696 << fuse_count << " fusion actions.";
697 module->set_config(module_config);
698 }
699
700 VLOG(1) << "Fusion count: " << fuse_count;
701
702 return changed;
703 }
704
AddFusionInstruction(HloInstruction * producer,HloInstruction * consumer,HloComputation * computation)705 HloInstruction* InstructionFusion::AddFusionInstruction(
706 HloInstruction* producer, HloInstruction* consumer,
707 HloComputation* computation) {
708 HloInstruction* fusion_instruction;
709 auto kind = ChooseKind(producer, consumer);
710 if (consumer->opcode() == HloOpcode::kFusion) {
711 fusion_instruction = consumer;
712 if (kind != fusion_instruction->fusion_kind()) {
713 fusion_instruction->set_fusion_kind(kind);
714 }
715 } else {
716 fusion_instruction = computation->AddInstruction(
717 HloInstruction::CreateFusion(consumer->shape(), kind, consumer));
718 TF_CHECK_OK(computation->ReplaceInstruction(consumer, fusion_instruction));
719 }
720 fusion_instruction->set_called_computations_execution_thread(
721 computation->execution_thread(),
722 /*skip_async_execution_thread_overwrite=*/false);
723 return fusion_instruction;
724 }
725
FuseInstruction(HloInstruction * fusion_instruction,HloInstruction * producer)726 HloInstruction* InstructionFusion::FuseInstruction(
727 HloInstruction* fusion_instruction, HloInstruction* producer) {
728 return fusion_instruction->FuseInstruction(producer);
729 }
730
UpdateReusedOperandsForFusion(HloInstruction * producer,HloInstruction * fusion_instruction)731 void InstructionFusion::UpdateReusedOperandsForFusion(
732 HloInstruction* producer, HloInstruction* fusion_instruction) {
733 // Find or compute the existing fusion reused operands. Note these reflect the
734 // state *before* the current fusion has taken place, although if we have
735 // replaced the consumer with a new single-element fusion, we will compute
736 // the new single-element fusion's reused operands here.
737 absl::flat_hash_set<const HloInstruction*>& fusion_reused_operands =
738 ReusedOperandsOf(fusion_instruction);
739
740 // If the producer is reused, replace it with its operands.
741 if (fusion_reused_operands.erase(producer)) {
742 fusion_reused_operands.insert(producer->operands().begin(),
743 producer->operands().end());
744 } else {
745 const absl::flat_hash_set<const HloInstruction*>& producer_reused_operands =
746 ReusedOperandsOf(producer);
747 // Otherwise add the producer's reused operands to the set.
748 fusion_reused_operands.insert(producer_reused_operands.begin(),
749 producer_reused_operands.end());
750 }
751 }
752
Fuse(HloInstruction * producer,HloInstruction * consumer,HloComputation * computation)753 HloInstruction* InstructionFusion::Fuse(HloInstruction* producer,
754 HloInstruction* consumer,
755 HloComputation* computation) {
756 VLOG(2) << "Fusing " << producer->ToString() << " into "
757 << consumer->ToString();
758 HloInstruction* fusion_instruction =
759 AddFusionInstruction(producer, consumer, computation);
760 UpdateReusedOperandsForFusion(producer, fusion_instruction);
761 FuseInstruction(fusion_instruction, producer);
762 if (fusion_instruction != producer && fusion_instruction != consumer) {
763 VLOG(2) << " created new fusion: " << fusion_instruction->ToString();
764 }
765 return fusion_instruction;
766 }
767
FuseIntoMultiOutput(HloInstruction * producer,HloInstruction * consumer,HloComputation * computation)768 HloInstruction* InstructionFusion::FuseIntoMultiOutput(
769 HloInstruction* producer, HloInstruction* consumer,
770 HloComputation* computation) {
771 VLOG(2) << "Multi-output fusing " << producer->ToString() << " into "
772 << consumer->ToString();
773 HloInstruction* fusion_instruction =
774 AddFusionInstruction(producer, consumer, computation);
775 UpdateReusedOperandsForFusion(producer, fusion_instruction);
776 fusion_instruction->FuseInstructionIntoMultiOutput(producer);
777 return fusion_instruction;
778 }
779
MultiOutputFusionCreatesCycle(HloInstruction * producer,HloInstruction * consumer,const HloReachabilityMap & reachability)780 bool InstructionFusion::MultiOutputFusionCreatesCycle(
781 HloInstruction* producer, HloInstruction* consumer,
782 const HloReachabilityMap& reachability) {
783 absl::flat_hash_set<int> operands;
784 for (const HloInstruction* operand : consumer->operands()) {
785 if (operand == producer) {
786 continue;
787 }
788
789 // If the reachability map already contains the producer and the operand of
790 // the consumer, and the producer can reach the operand, then we know for
791 // sure MultiOutputFusion would create a cycle. If not, we need to do a DFS
792 // traversal of the computation to verify that this multioutput fusion would
793 // not create a cycle.
794 if (reachability.IsPresent(producer) && reachability.IsPresent(operand) &&
795 reachability.IsReachable(producer, operand)) {
796 return true;
797 }
798 operands.insert(operand->unique_id());
799 }
800
801 // Do a DFS on the producer to see if any of the other consumer operands are
802 // reachable in the current state of the graph.
803 std::vector<HloInstruction*> worklist = producer->users();
804 absl::flat_hash_set<int> visits;
805 while (!worklist.empty()) {
806 const HloInstruction* user = worklist.back();
807 worklist.pop_back();
808 if (operands.count(user->unique_id()) != 0) {
809 return true;
810 }
811 if (visits.count(user->unique_id()) == 0) {
812 visits.insert(user->unique_id());
813 worklist.insert(worklist.end(), user->users().begin(),
814 user->users().end());
815 }
816 }
817 return false;
818 }
819
820 namespace {
821
822 // Extracts instruction from the fusion that satisfies filter. If no or multiple
823 // instructions in the fusion satisfy filter, returns nullptr.
ExtractInstruction(const HloInstruction * hlo,const HloPredicate & filter)824 const HloInstruction* ExtractInstruction(const HloInstruction* hlo,
825 const HloPredicate& filter) {
826 if (filter(hlo)) {
827 return hlo;
828 }
829 if (hlo->opcode() != HloOpcode::kFusion) {
830 return nullptr;
831 }
832 const HloInstruction* match = nullptr;
833 for (HloInstruction* inst :
834 hlo->fused_instructions_computation()->instructions()) {
835 if (filter(inst)) {
836 if (match == nullptr) {
837 match = inst;
838 } else {
839 return nullptr;
840 }
841 }
842 }
843 return match;
844 }
845
ExtractInstruction(const HloInstruction * hlo,HloOpcode opcode)846 const HloInstruction* ExtractInstruction(const HloInstruction* hlo,
847 HloOpcode opcode) {
848 return ExtractInstruction(hlo, [opcode](const HloInstruction* inst) {
849 return inst->opcode() == opcode;
850 });
851 }
852
853 // Returns true if fusing a slice or dynamic slice in producer into a dynamic
854 // update slice fusion in consumer is safe. It is not safe to fuse the slice and
855 // DUS when fusing will cause another non-elementwise op to share operands with
856 // the DUS in-place buffer.
IsSafeToFuseSliceIntoDusFusion(const HloInstruction * producer,const HloInstruction * consumer)857 bool IsSafeToFuseSliceIntoDusFusion(const HloInstruction* producer,
858 const HloInstruction* consumer) {
859 CHECK_EQ(consumer->opcode(), HloOpcode::kFusion);
860 const HloInstruction* dus =
861 ExtractInstruction(consumer, HloOpcode::kDynamicUpdateSlice);
862 CHECK_NE(dus, nullptr);
863
864 // Use a memoization map to avoid exponential runtime.
865 absl::flat_hash_map<const HloInstruction*, bool> nonelementwise_memo;
866 // Recursively check if the instruction or its users (or their users) are
867 // non-elementwise with the exception of the DUS. We have already verified
868 // that the slice and DUS are compatible since their indices match.
869 HloPredicate has_nonelementwise_uses_except_dus =
870 [&](const HloInstruction* instruction) {
871 auto record_and_return = [&](bool val) {
872 nonelementwise_memo[instruction] = val;
873 return val;
874 };
875 auto nonelementwise_memo_it = nonelementwise_memo.find(instruction);
876 if (nonelementwise_memo_it != nonelementwise_memo.end()) {
877 return nonelementwise_memo_it->second;
878 }
879 if (instruction != dus && !instruction->IsElementwise() &&
880 instruction->opcode() != HloOpcode::kParameter) {
881 return record_and_return(true);
882 }
883 return record_and_return(absl::c_any_of(
884 instruction->users(), has_nonelementwise_uses_except_dus));
885 };
886 for (int i = 0; i < consumer->operand_count(); ++i) {
887 if (consumer->operand(i) == producer &&
888 has_nonelementwise_uses_except_dus(consumer->fused_parameter(i))) {
889 VLOG(4) << "Found a different elementwise";
890 return false;
891 }
892 }
893 return true;
894 }
895
896 } // namespace
897
ShouldFuseInPlaceOp(const HloInstruction * producer,const HloInstruction * consumer)898 /*static*/ FusionDecision InstructionFusion::ShouldFuseInPlaceOp(
899 const HloInstruction* producer, const HloInstruction* consumer) {
900 // Don't fuse if the producer is a non-elementwise op that has the same
901 // operand as an in-place operand of the consumer. The consumer will modify
902 // the buffer in-place, which will cause producer's operand to change if we
903 // allow them to fuse.
904 std::vector<std::pair<HloOperandIndex, ShapeIndex>>
905 in_place_input_output_pairs =
906 HloDataflowAnalysis::GetInPlaceInputOutputPairs(
907 const_cast<HloInstruction*>(consumer));
908 for (auto& pair : in_place_input_output_pairs) {
909 int operand_number = pair.first.operand_number;
910 VLOG(4) << "in/out pair: " << operand_number << " "
911 << pair.first.operand_index.ToString() << " "
912 << pair.second.ToString();
913 // Check if the consumer also has an additional operand that has the same
914 // value as the in-place buffer. If so, it's unsafe to fuse.
915 for (int i = 0; i < consumer->operand_count(); ++i) {
916 if (i != operand_number &&
917 consumer->operand(operand_number) == consumer->operand(i)) {
918 return "The consumer is an in-place operation that has an additional "
919 "operand that has the same value as the in-place buffer";
920 }
921 }
922 if (!producer->IsElementwise() &&
923 absl::c_find(producer->operands(), consumer->operand(operand_number)) !=
924 producer->operands().end()) {
925 VLOG(4) << "Found non-elementwise operand that uses the same operand of "
926 "an in-place consumer";
927 auto get_real_operand = [](const HloInstruction* op,
928 const HloInstruction* operand) {
929 if (op->opcode() == HloOpcode::kFusion &&
930 operand->opcode() == HloOpcode::kParameter) {
931 return op->operand(operand->parameter_number());
932 }
933 return operand;
934 };
935
936 auto get_constant_operand =
937 [](const HloInstruction* operand) -> std::optional<int> {
938 if (operand->IsConstant()) {
939 return operand->literal().GetFirstInteger();
940 }
941 return std::nullopt;
942 };
943 // A common special case is a slice or dynamic-slice and a
944 // dynamic-update-slice that use the same indices. This pattern is safe.
945 const HloInstruction* dus =
946 ExtractInstruction(consumer, HloOpcode::kDynamicUpdateSlice);
947 const HloInstruction* producer_nonelementwise =
948 ExtractInstruction(producer, [](const HloInstruction* inst) {
949 return inst->opcode() != HloOpcode::kFusion &&
950 !inst->IsElementwise();
951 });
952 if (dus == nullptr || producer_nonelementwise == nullptr ||
953 producer_nonelementwise->shape() != dus->operand(1)->shape()) {
954 return "Consumer is not a dus or the producer fusion has multiple "
955 "non-elementwise ops, bailing.";
956 }
957 if (producer_nonelementwise->opcode() == HloOpcode::kSlice) {
958 for (int i = 0; i < dus->shape().rank(); ++i) {
959 const HloInstruction* dus_operand =
960 get_real_operand(consumer, dus->operand(2 + i));
961 auto constant_operand = get_constant_operand(dus_operand);
962 if (!constant_operand ||
963 *constant_operand != producer_nonelementwise->slice_starts(i) ||
964 producer_nonelementwise->slice_strides(i) != 1) {
965 return "DUS and slice index mismatch";
966 }
967 }
968 VLOG(4) << "DUS and slice index match";
969 if (consumer->opcode() == HloOpcode::kFusion &&
970 !IsSafeToFuseSliceIntoDusFusion(producer, consumer)) {
971 return "Fusing slice into DUS will also fuse another non-elementwise "
972 "op with shared operand as DUS.";
973 }
974 return {};
975 }
976 if (producer_nonelementwise->opcode() == HloOpcode::kDynamicSlice) {
977 for (int i = 0; i < dus->shape().rank(); ++i) {
978 const HloInstruction* ds_operand = get_real_operand(
979 producer, producer_nonelementwise->operand(1 + i));
980 const HloInstruction* dus_operand =
981 get_real_operand(consumer, dus->operand(2 + i));
982 auto constant_ds_operand = get_constant_operand(ds_operand);
983 auto constant_dus_operand = get_constant_operand(dus_operand);
984 if (constant_ds_operand != constant_dus_operand ||
985 (!constant_ds_operand && ds_operand != dus_operand)) {
986 return "DUS and DS index mismatch";
987 }
988 }
989 VLOG(4) << "DUS and DS index match";
990 if (consumer->opcode() == HloOpcode::kFusion &&
991 !IsSafeToFuseSliceIntoDusFusion(producer, consumer)) {
992 return "Fusing DS into DUS will also fuse another non-elementwise op "
993 "with shared operand as DUS.";
994 }
995 return {};
996 }
997 return "unrecognized inplace update non-elementwise output pair";
998 }
999 }
1000 return {};
1001 }
1002
ShouldFuse(HloInstruction * consumer,int64_t operand_index)1003 FusionDecision InstructionFusion::ShouldFuse(HloInstruction* consumer,
1004 int64_t operand_index) {
1005 HloInstruction* producer = consumer->mutable_operand(operand_index);
1006
1007 // Don't fuse across a root instruction.
1008 if (producer == producer->parent()->root_instruction()) {
1009 return "not fusing into the output of the root instruction";
1010 }
1011
1012 // Cost condition: don't duplicate expensive instructions.
1013 if (FusionWouldDuplicate(*producer, *consumer) &&
1014 (!may_duplicate_ || is_expensive_(*producer)) &&
1015 !IsAlwaysDuplicable(*producer)) {
1016 return may_duplicate_ ? "expensive producer would be duplicated"
1017 : "fusion pass cannot duplicate";
1018 }
1019
1020 if (NoFusionPossible fusible = !ShouldFuseInPlaceOp(producer, consumer)) {
1021 return !fusible;
1022 }
1023
1024 return {};
1025 }
1026
// Picks the fusion kind to use when fusing `producer` into `consumer`.
// This implementation ignores both arguments and always selects kLoop.
HloInstruction::FusionKind InstructionFusion::ChooseKind(
    const HloInstruction* producer, const HloInstruction* consumer) {
  constexpr HloInstruction::FusionKind kChosenKind =
      HloInstruction::FusionKind::kLoop;
  return kChosenKind;
}
1031
ReusedOperandsOf(const HloInstruction * instruction)1032 absl::flat_hash_set<const HloInstruction*>& InstructionFusion::ReusedOperandsOf(
1033 const HloInstruction* instruction) {
1034 std::unique_ptr<absl::flat_hash_set<const HloInstruction*>>& reused_operands =
1035 reused_fusion_operands_[instruction];
1036 if (reused_operands != nullptr) {
1037 return *reused_operands;
1038 }
1039 reused_operands =
1040 std::make_unique<absl::flat_hash_set<const HloInstruction*>>();
1041
1042 for (int64_t i = 0; i < instruction->operand_count(); ++i) {
1043 bool reuses = instruction->ReusesOperandElements(i);
1044 if (reuses) {
1045 // We cache the operand corresponding to the fusion parameter, because the
1046 // parameter numbers would be invalidated after the next fusion.
1047 reused_operands->insert(instruction->operand(i));
1048 }
1049 }
1050 return *reused_operands;
1051 }
1052
ReusesOperandElements(const HloInstruction * consumer,int64_t operand_index)1053 bool InstructionFusion::ReusesOperandElements(const HloInstruction* consumer,
1054 int64_t operand_index) {
1055 auto operand = consumer->operand(operand_index);
1056 return ReusedOperandsOf(consumer).contains(operand);
1057 }
1058
1059 } // namespace xla
1060