/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"

namespace xla {
namespace gpu {

namespace {

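// Returns the element type shared by all outputs of `fusible`. Fusibles with
// mixed output types are rejected earlier by IsFusibleCandidate(); the
// invariant is CHECK-enforced here.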
PrimitiveType GetUniqueOutputTypeOfFusible(const HloInstruction& fusible) {
  auto outputs = GetOutputsOfFusible(fusible);
  CHECK(!outputs.empty());
  PrimitiveType first_output_type = outputs[0]->shape().element_type();
  for (size_t i = 1; i < outputs.size(); ++i) {
    PrimitiveType cur_output_type = outputs[i]->shape().element_type();
    CHECK(first_output_type == cur_output_type)
        << "Output types are expected to be unique, but see "
        << PrimitiveType_Name(first_output_type) << " and "
        << PrimitiveType_Name(cur_output_type);
  }

  return first_output_type;
}

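// Encapsulates the horizontal-fusion logic for one computation: sibling kLoop
// fusions (and element-wise instructions) feeding a common consumer are fused
// into a single fusion whose concatenated outputs are sliced back apart,
// trading many small kernel launches for one larger launch.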
class HorizontalLoopFusionImpl {
 public:
  explicit HorizontalLoopFusionImpl(HloComputation* computation,
                                    absl::string_view prefix)
      : computation_(computation), prefix_(prefix) {}

  ~HorizontalLoopFusionImpl() {}

  StatusOr<bool> Run();

 private:
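  // Builds one horizontal fusion instruction out of `fused_fusion_instrs`,
  // adds it to `computation_`, and replaces each original instruction with
  // (bitcasts of) the corresponding elements of the new fusion's output tuple.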
  Status Fuse(absl::Span<HloInstruction*> fused_fusion_instrs);

  // Horizontally fuses `fused_fusion_instrs`. It is required that each of
  // `fused_fusion_instrs` is a kLoop fusion. Also, we require their numbers of
  // outputs to be the same, so that each output will be fused/concatenated with
  // the same number of outputs from other fused fusion instrs. Then, all the
  // fused outputs still have the same shapes for kernel generation.
  //
  // Returns the fused computation in `uniq_computation` and the operands that
  // are used by `uniq_computation`.
  Status CreateFusedComputation(
      absl::Span<HloInstruction*> fused_fusion_instrs,
      std::unique_ptr<HloComputation>* uniq_computation,
      std::vector<HloInstruction*>* bound_operands);

  // FusionCandidates collects profitable candidates for a given consumer
  // instruction. GetNextSpanOfFusions() can then be iteratively invoked to
  // acquire the next set of fusion candidates based on some heuristics.
  class FusionCandidates {
   public:
    explicit FusionCandidates(HloInstruction* consumer)
        : fusible_instrs_(), pos_(0) {
      Initialize(consumer);
    }

    // Gets a span of fusions to be fused.
    absl::Span<HloInstruction*> GetNextSpanOfFusions();

   private:
    void Initialize(HloInstruction*);

    std::vector<HloInstruction*> fusible_instrs_;
    // `pos_` points to the start position of the next span.
    size_t pos_;
  };

  HloComputation* computation_;
  std::string prefix_;
};  // HorizontalLoopFusionImpl

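// Returns whether `instr` may participate in horizontal fusion: either a
// non-nullary element-wise instruction or a kLoop fusion whose outputs all
// share a single element type.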
bool IsFusibleCandidate(const HloInstruction& instr) {
  // Require no further check for element-wise instructions.
  if (instr.IsElementwise() && instr.operand_count() > 0) {
    return true;
  }

  // Exclude fusions other than kLoop.
  if (!instr.IsLoopFusion()) {
    return false;
  }

  // Cannot support fusions that have multiple output types, because the
  // concatenate (inserted for horizontal fusion) requires the same type
  // for all of its operands.
  auto outputs = GetOutputsOfFusible(instr);
  CHECK(!outputs.empty());
  const HloInstruction* first_output = outputs[0];
  for (size_t i = 1; i < outputs.size(); ++i) {
    if (first_output->shape().element_type() !=
        outputs[i]->shape().element_type()) {
      return false;
    }
  }

  return true;
}

// Returns whether `instr` is a profitable candidate to be horizontally fused.
// Since the primary benefit of horizontal fusion comes from reducing the
// kernel launch overhead, we want to exclude instructions with insignificant
// kernel launch overhead. In other words, we exclude instructions whose
// computation latencies are longer than their launch latencies. We estimate
// the computation latency of a given instruction by its shapes and the
// instruction count in its fused computation. We roughly observe that if a
// fusion instruction has shapes smaller than `kShapeThreshold` and fewer
// instructions than `kInstrCountThreshold`, it is launch-latency-bound and
// profitable for horizontal fusion.
bool IsProfitableFusionCandidate(const HloInstruction& instr) {
  constexpr int64_t kShapeThreshold = 128 * 2048;
  constexpr int64_t kInstrCountThreshold = 30;
  const HloInstruction* root = (instr.opcode() == HloOpcode::kFusion)
                                   ? instr.fused_expression_root()
                                   : &instr;

  // Too-large shapes are not easily profitable.
  if (root->opcode() == HloOpcode::kTuple) {
    // Since all output shapes are the same, use the first shape as the
    // representative.
    Shape shape = root->operand(0)->shape();
    if (ShapeUtil::ElementsIn(shape) > kShapeThreshold) {
      return false;
    }
  } else {
    Shape shape = root->shape();
    if (ShapeUtil::ElementsIn(shape) > kShapeThreshold) {
      return false;
    }
  }

  // Having too many instructions is not easily profitable.
  if (instr.opcode() == HloOpcode::kFusion &&
      instr.fused_instruction_count() > kInstrCountThreshold) {
    return false;
  }

  return true;
}

// Returns whether `instr` has only row-major layouts.
// The horizontal fusion excludes computations with non-row-major layouts,
// because fusing computations with different layouts can result in uncoalesced
// memory accesses and cause significant performance overhead.
bool HasOnlyRowMajorLayout(const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kFusion) {
    return LayoutUtil::IsMonotonicWithDim0Major(instr.shape().layout());
  }

  auto fused_instrs = instr.fused_instructions_computation()->instructions();
  for (HloInstruction* i : fused_instrs) {
    if (!LayoutUtil::IsDenseArray(i->shape())) {
      continue;
    }
    if (!LayoutUtil::IsMonotonicWithDim0Major(i->shape().layout())) {
      return false;
    }
  }
  return true;
}

// Returns whether any operand of `instr` is a parameter instruction that
// is shared with `fusion_instrs`.
bool AnyOpndIsParamSharedAmongFusions(
    const HloInstruction* instr,
    const absl::flat_hash_set<HloInstruction*>& fusion_instrs) {
  return absl::c_any_of(instr->operands(), [&](const HloInstruction* opnd) {
    return opnd->opcode() == HloOpcode::kParameter &&
           absl::c_any_of(opnd->users(), [&](const HloInstruction* user) {
             return user != instr && fusion_instrs.contains(user);
           });
  });
}

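// Collects the fusible operands of `consumer`, filters out candidates that are
// unsafe or unlikely to be profitable, and sorts the survivors so that
// instructions that can be fused together become adjacent.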
void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
    HloInstruction* consumer) {
  // First, find out all potential target candidates. We will filter out
  // unsupported/non-profitable cases below.
  absl::flat_hash_set<HloInstruction*> fusible_candidates;
  std::vector<HloInstruction*> ordered_fusible_candidates;
  for (HloInstruction* opnd : consumer->operands()) {
    HloInstruction* predecessor = opnd->LatestNonGteAncestor();
    // We support kLoop fusions and element-wise HLOs for now. We may extend
    // the supported set if the need arises.
    if (IsFusibleCandidate(*predecessor)) {
      if (fusible_candidates.insert(predecessor).second) {
        // Add unseen fusion to ordered list.
        ordered_fusible_candidates.push_back(predecessor);
      }
    }
  }

  for (HloInstruction* instr : ordered_fusible_candidates) {
    if (!IsConsumerTheOnlyNonRootUser(*instr, *consumer)) {
      VLOG(2) << "Reject maybe illegal instr " << instr->ToString()
              << "; including it may create cycles in HLO.";
      continue;
    } else if (!IsProfitableFusionCandidate(*instr)) {
      VLOG(2) << "Reject may-not-be profitable fusion instr "
              << instr->ToString();
      continue;
    } else if (!HasOnlyRowMajorLayout(*instr)) {
      VLOG(2) << "Reject non-row-major fusion instr " << instr->ToString();
      continue;
    } else if (AnyOpndIsParamSharedAmongFusions(instr, fusible_candidates)) {
      // Don't fuse fusions whose operands are parameter instructions that are
      // shared among fusions, because we cannot i/o alias the produced
      // horizontal fusion due to the concat insertion.
      VLOG(2) << "Reject the fusion instr because it shares parameter with"
              << " other fusion candidates, instr: " << instr->ToString();
      continue;
    } else {
      VLOG(2) << "Find a fusion candidate " << instr->ToString();
      // Encapsulate it into a fusion computation for unified representation
      // for later processing.
      fusible_instrs_.push_back(instr);
    }
  }

  // Sort `fusible_instrs_` according to output types, the number of outputs,
  // and instruction counts, because we only fuse instructions with the same
  // number/type of outputs and whose computations have the same instruction
  // count.
  std::sort(
      fusible_instrs_.begin(), fusible_instrs_.end(),
      [&](const HloInstruction* a, const HloInstruction* b) {
        if (GetUniqueOutputTypeOfFusible(*a) !=
            GetUniqueOutputTypeOfFusible(*b)) {
          return GetUniqueOutputTypeOfFusible(*a) <
                 GetUniqueOutputTypeOfFusible(*b);
        } else if (GetOutputSizeOfFusible(*a) != GetOutputSizeOfFusible(*b)) {
          return GetOutputSizeOfFusible(*a) < GetOutputSizeOfFusible(*b);
        } else {
          return GetInstrCountOfFusible(*a) < GetInstrCountOfFusible(*b);
        }
      });
}

// Gets the next span of fusion instructions to be fused.
absl::Span<HloInstruction*>
HorizontalLoopFusionImpl::FusionCandidates::GetNextSpanOfFusions() {
  if (pos_ >= fusible_instrs_.size()) {
    return absl::Span<HloInstruction*>();
  }

  // Fusing too many computations at a time may not be easily profitable and
  // may increase compile time due to large kernels. Set a limit to it.
  constexpr int64_t kMaxFusionBatchSize = 32;
  // CUDA has a parameter size limit of ~4k bytes.
  constexpr int64_t kMaxCudaParamSize = 4000;
  size_t accum_io_size = 0;
  auto reach_max_fusion_batch_size = [&](size_t left, size_t right) -> bool {
    if (right - left >= kMaxFusionBatchSize) {
      return true;
    }

    accum_io_size += fusible_instrs_.at(right)->operand_count() +
                     GetOutputSizeOfFusible(*fusible_instrs_.at(right));

    if (accum_io_size * 8 >= kMaxCudaParamSize) {
      return true;
    }

    return false;
  };

  size_t left = pos_;
  size_t right = pos_ + 1;
  size_t first_output_size = GetOutputSizeOfFusible(*fusible_instrs_[left]);
  PrimitiveType first_output_type =
      GetUniqueOutputTypeOfFusible(*fusible_instrs_[left]);
  for (; right < fusible_instrs_.size(); ++right) {
    PrimitiveType cur_output_type =
        GetUniqueOutputTypeOfFusible(*fusible_instrs_[right]);
    if (first_output_type != cur_output_type) {
      // Cannot fuse computations that have multiple output types.
      break;
    } else if (first_output_size !=
               GetOutputSizeOfFusible(*fusible_instrs_[right])) {
      // Cannot fuse computations that have different numbers of outputs.
      break;
    } else if (GetInstrCountOfFusible(*fusible_instrs_[left]) !=
               GetInstrCountOfFusible(*fusible_instrs_[right])) {
      // Do not fuse computations of different instruction counts as it may
      // introduce control divergence. This is a very simple heuristic to avoid
      // fusing computations with too much discrepancy, and we may improve it
      // when the need arises.
      break;
    } else if (reach_max_fusion_batch_size(left, right)) {
      // Hit max fusion batch size.
      break;
    }
  }

  pos_ = right;
  return absl::MakeSpan(fusible_instrs_).subspan(left, right - left);
}

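// Builds the horizontally fused computation in four steps: (1) create one
// parameter per operand of every fused instruction, (2) clone each fused
// computation onto the new parameters, (3) reshape each output to 1D and
// concatenate the corresponding outputs across the fused instructions, and
// (4) slice the concatenations back apart and tuple the slices as the root.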
Status HorizontalLoopFusionImpl::CreateFusedComputation(
    absl::Span<HloInstruction*> fused_fusion_instrs,
    std::unique_ptr<HloComputation>* uniq_computation,
    std::vector<HloInstruction*>* bound_operands) {
  // First, build a computation with only params.
  HloComputation::Builder b(prefix_ + "horizontally_fused_computation");
  size_t fused_comp_param_id = 0;
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    auto old_params = fused_fusion_instrs[i]->fused_parameters();
    for (size_t j = 0; j < old_params.size(); ++j) {
      HloInstruction* bound_opnd = fused_fusion_instrs[i]->mutable_operand(j);
      // Parameters are named in the form of param_i_j.
      b.AddInstruction(HloInstruction::CreateParameter(
          fused_comp_param_id++, bound_opnd->shape(),
          absl::StrCat("param_", i, "_", j)));
      bound_operands->push_back(bound_opnd);
    }
  }
  // Always create a dummy tuple instruction to serve as the root of the
  // computation, as the existence of a root instruction is required by the
  // HloComputation. The real root instruction will replace it below.
  HloInstruction* dummy_root = b.AddInstruction(
      HloInstruction::CreateTuple(std::vector<HloInstruction*>{}));
  *uniq_computation = b.Build(dummy_root);
  HloComputation* comp = uniq_computation->get();

  // Prepare clone_map, which maps an old operand to its new counterpart.
  absl::flat_hash_map<const HloInstruction*, HloInstruction*> clone_map;
  size_t new_param_id = 0;
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    auto old_params = fused_fusion_instrs[i]->fused_parameters();
    for (size_t j = 0; j < old_params.size(); ++j) {
      HloInstruction* old_param = old_params[j];
      HloInstruction* new_param = comp->parameter_instruction(new_param_id++);
      clone_map.insert({old_param, new_param});
    }
  }

  // Clone every fused computation.
  const OpMetadata* metadata = nullptr;
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    auto def_to_use_order = fused_fusion_instrs[i]
                                ->fused_instructions_computation()
                                ->MakeInstructionPostOrder();
    for (HloInstruction* old_instr : def_to_use_order) {
      if (old_instr->opcode() == HloOpcode::kParameter) {
        // Parameters have already been created.
        continue;
      }
      std::vector<HloInstruction*> new_opnds;
      const auto& old_opnds = old_instr->operands();
      new_opnds.reserve(old_opnds.size());
      for (HloInstruction* old_opnd : old_opnds) {
        CHECK(clone_map.find(old_opnd) != clone_map.end());
        new_opnds.push_back(clone_map[old_opnd]);
      }
      HloInstruction* new_instr = comp->AddInstruction(
          old_instr->CloneWithNewOperands(old_instr->shape(), new_opnds));
      clone_map.insert({old_instr, new_instr});
      // Get the metadata from the last fused instruction.
      metadata = &old_instr->metadata();
    }
  }

  std::vector<HloInstruction*> concated_outputs;
  // Since we require each fusion to have the same number of outputs, we can
  // simply use the first fusion as the representative for output size.
  size_t fused_instr_output_size =
      GetOutputSizeOfFusible(*fused_fusion_instrs[0]);
  for (size_t i = 0; i < fused_instr_output_size; ++i) {
    std::vector<HloInstruction*> reshapes(fused_fusion_instrs.size());
    for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
      const HloInstruction* old_output =
          GetOutputsOfFusible(*fused_fusion_instrs[j])[i];
      HloInstruction* new_output = clone_map[old_output];
      TF_ASSIGN_OR_RETURN(
          reshapes[j],
          MakeReshapeHlo(ShapeUtil::MakeShapeWithLayout(
                             new_output->shape().element_type(),
                             {ShapeUtil::ElementsIn(new_output->shape())},
                             /*minor_to_major=*/std::vector<int64_t>(1, 0)),
                         new_output));
    }
    TF_ASSIGN_OR_RETURN(HloInstruction * concated_output,
                        MakeConcatHlo(reshapes, 0));
    concated_outputs.push_back(concated_output);
  }

  // Make slices of outputs.
  std::vector<HloInstruction*> output_slices(concated_outputs.size() *
                                             fused_fusion_instrs.size());
  for (size_t i = 0; i < concated_outputs.size(); ++i) {
    HloInstruction* concated_output = concated_outputs[i];
    int64_t slice_start = 0;
    // Create a slice per fused computation.
    for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
      const HloInstruction* old_output =
          GetOutputsOfFusible(*fused_fusion_instrs[j])[i];
      Shape shape = old_output->shape();
      int64_t slice_limit = slice_start + ShapeUtil::ElementsIn(shape);
      TF_ASSIGN_OR_RETURN(
          output_slices[concated_outputs.size() * j + i],
          MakeSliceHlo(concated_output, {slice_start}, {slice_limit},
                       /*strides=*/{1}));
      slice_start = slice_limit;
    }
  }

  // Make a tuple of output_slices.
  HloInstruction* tuple = comp->AddInstruction(
      HloInstruction::CreateTuple(output_slices), metadata);
  comp->set_root_instruction(tuple, /*accept_different_shape=*/true);
  TF_RETURN_IF_ERROR(comp->RemoveInstruction(dummy_root));

  return OkStatus();
}

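// Replaces `fused_fusion_instrs` with a single horizontal fusion instruction
// and re-wires their users through bitcasts (or tuples of bitcasts) of the new
// fusion's outputs.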
Status HorizontalLoopFusionImpl::Fuse(
    absl::Span<HloInstruction*> fused_fusion_instrs) {
  // Fuse fused_fusion_instrs and replace them with the new fused computation.
  std::unique_ptr<HloComputation> uniq_computation;
  std::vector<HloInstruction*> bound_operands;
  TF_RETURN_IF_ERROR(CreateFusedComputation(
      fused_fusion_instrs, &uniq_computation, &bound_operands));
  HloComputation* fused_comp = computation_->parent()->AddEmbeddedComputation(
      std::move(uniq_computation));
  HloInstruction* hori_fusion_instr = computation_->AddInstruction(
      HloInstruction::CreateFusion(fused_comp->root_instruction()->shape(),
                                   HloInstruction::FusionKind::kInput,
                                   bound_operands, fused_comp, prefix_),
      &fused_comp->root_instruction()->metadata());
  fused_comp->SetFusionInstruction(hori_fusion_instr);

  // Insert bitcasts and replace the corresponding users. Note that we do not
  // insert the bitcasts in the fused computation, as that does not fit into
  // the slice input fusion pattern. However, inserting bitcasts outside the
  // fused computation adds no performance cost.
  size_t total_output_id = 0;
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    std::vector<HloInstruction*> bitcasts;
    HloInstruction* fused_instr = fused_fusion_instrs[i];
    size_t num_outputs = GetOutputSizeOfFusible(*fused_instr);
    for (size_t j = 0; j < num_outputs; ++j) {
      const HloInstruction* output = GetOutputsOfFusible(*fused_instr)[j];
      TF_ASSIGN_OR_RETURN(
          HloInstruction * gep,
          MakeGetTupleElementHlo(hori_fusion_instr, total_output_id++));
      bitcasts.push_back(computation_->AddInstruction(
          HloInstruction::CreateBitcast(output->shape(), gep)));
    }
    HloInstruction* bitcast_or_tuple =
        (bitcasts.size() == 1) ? bitcasts.at(0)
                               : computation_->AddInstruction(
                                     HloInstruction::CreateTuple(bitcasts));
    TF_RETURN_IF_ERROR(
        computation_->ReplaceInstruction(fused_instr, bitcast_or_tuple));
  }

  return OkStatus();
}

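// Walks `computation_` from users to definitions and, for each consumer,
// repeatedly pulls spans of compatible candidates and fuses every span that
// contains more than one instruction.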
StatusOr<bool> HorizontalLoopFusionImpl::Run() {
  bool changed = false;
  XLA_VLOG_LINES(3, computation_->ToString());

  // Traverse from use to def. Bitcasts are placed after h-fusions to resolve
  // shape mismatches, but bitcasts could prevent future h-fusions from
  // happening. So a bottom-up, use-to-def order should be more favorable. It
  // also helps save the compiler iterations needed to reach the fixed point.
  std::vector<HloInstruction*> use_to_def_order =
      computation_->MakeInstructionPostOrder();
  absl::c_reverse(use_to_def_order);
  for (size_t i = 0; i < use_to_def_order.size(); ++i) {
    HloInstruction* consumer = use_to_def_order[i];
    HorizontalLoopFusionImpl::FusionCandidates fusion_candidates(consumer);
    while (true) {
      auto fusibles = fusion_candidates.GetNextSpanOfFusions();
      if (fusibles.empty()) {
        break;
      } else if (fusibles.size() == 1) {
        // Skip; there is just one fused_instr.
        continue;
      }

      changed = true;
      // Convert fusibles into fusion_instrs to simplify the implementation of
      // `Fuse()`.
      std::vector<HloInstruction*> fusion_instrs;
      for (HloInstruction* instr : fusibles) {
        if (instr->opcode() == HloOpcode::kFusion) {
          fusion_instrs.push_back(instr);
        } else {
          TF_ASSIGN_OR_RETURN(
              HloInstruction * fusion_instr,
              MakeFusionInstruction(instr, HloInstruction::FusionKind::kLoop));
          fusion_instrs.push_back(fusion_instr);
        }
      }
      TF_RETURN_IF_ERROR(Fuse(absl::MakeSpan(fusion_instrs)));
    }
  }

  return changed;
}

}  // namespace

StatusOr<bool> GpuHorizontalLoopFusion::RunOnComputation(
    HloComputation* computation) {
  HorizontalLoopFusionImpl horizontal_fusion_impl(computation, prefix_);
  return horizontal_fusion_impl.Run();
}

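// A minimal usage sketch, assuming the standard HloPassPipeline API (the
// pipeline name below is hypothetical, not taken from this file):
//
//   HloPassPipeline pipeline("gpu-fusion");
//   pipeline.AddPass<GpuHorizontalLoopFusion>();
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));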
StatusOr<bool> GpuHorizontalLoopFusion::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool changed = false;
  VLOG(2) << "Run horizontal fusion.";

  // Running on the entry computation is enough.
  TF_ASSIGN_OR_RETURN(changed, RunOnComputation(module->entry_computation()));

  return changed;
}

}  // namespace gpu
}  // namespace xla