/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/horizontal_loop_fusion.h"

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"

namespace xla {
namespace gpu {

namespace {

PrimitiveType GetUniqueOutputTypeOfFusible(const HloInstruction& fusible) {
  auto outputs = GetOutputsOfFusible(fusible);
  CHECK(!outputs.empty());
  PrimitiveType first_output_type = outputs[0]->shape().element_type();
  for (size_t i = 1; i < outputs.size(); ++i) {
    PrimitiveType cur_output_type = outputs[i]->shape().element_type();
    CHECK(first_output_type == cur_output_type)
        << "Output types are expected to be unique, but see "
        << PrimitiveType_Name(first_output_type) << " and "
        << PrimitiveType_Name(cur_output_type);
  }

  return first_output_type;
}

class HorizontalLoopFusionImpl {
 public:
  explicit HorizontalLoopFusionImpl(HloComputation* computation,
                                    absl::string_view prefix)
      : computation_(computation), prefix_(prefix) {}

  ~HorizontalLoopFusionImpl() {}

  StatusOr<bool> Run();

 private:
  Status Fuse(absl::Span<HloInstruction*> fused_fusion_instrs);

  // Horizontally fuses `fused_fusion_instrs`. It is required that each of
  // `fused_fusion_instrs` is a kLoop fusion. Also, we require their numbers of
  // outputs to be the same, so that each output will be fused/concatenated
  // with the same number of outputs from other fused fusion instrs. Then, all
  // the fused outputs still have the same shapes for kernel generation.
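  //
  // For example (sketch), horizontally fusing two single-output fusions f1 and
  // f2, whose outputs have n1 and n2 elements of the same element type,
  // produces a computation of roughly this form:
  //   concat = concatenate(reshape(f1_root), reshape(f2_root))  // rank-1
  //   root   = tuple(slice(concat, [0:n1]), slice(concat, [n1:n1+n2]))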
  //
  // Returns the fused computation in `uniq_computation` and the operands that
  // are used by `uniq_computation`.
  Status CreateFusedComputation(
      absl::Span<HloInstruction*> fused_fusion_instrs,
      std::unique_ptr<HloComputation>* uniq_computation,
      std::vector<HloInstruction*>* bound_operands);

  // FusionCandidates collects profitable candidates for a given consumer
  // instruction. GetNextSpanOfFusions() can then be iteratively invoked to
  // acquire the next set of fusion candidates based on some heuristics.
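  //
  // Typical use (as in Run() below): construct a FusionCandidates for a
  // consumer and repeatedly call GetNextSpanOfFusions() until it returns an
  // empty span.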
  class FusionCandidates {
   public:
    explicit FusionCandidates(HloInstruction* consumer)
        : fusible_instrs_(), pos_(0) {
      Initialize(consumer);
    }

    // Gets a span of fusions to be fused.
    absl::Span<HloInstruction*> GetNextSpanOfFusions();

   private:
    void Initialize(HloInstruction*);

    std::vector<HloInstruction*> fusible_instrs_;
    // `pos_` points to the start position of the next span.
    size_t pos_;
  };

  HloComputation* computation_;
  std::string prefix_;
};  // HorizontalLoopFusionImpl

bool IsFusibleCandidate(const HloInstruction& instr) {
  // Element-wise instructions require no further checks.
  if (instr.IsElementwise() && instr.operand_count() > 0) {
    return true;
  }

  // Exclude fusions other than kLoop.
  if (!instr.IsLoopFusion()) {
    return false;
  }

  // Cannot support fusions that have multiple output types, because the
  // concatenate (inserted for horizontal fusion) requires the same element
  // type for all of its operands.
  auto outputs = GetOutputsOfFusible(instr);
  CHECK(!outputs.empty());
  const HloInstruction* first_output = outputs[0];
  for (size_t i = 1; i < outputs.size(); ++i) {
    if (first_output->shape().element_type() !=
        outputs[i]->shape().element_type()) {
      return false;
    }
  }

  return true;
}

// Returns whether `instr` is a profitable candidate to be horizontally fused.
// Since the primary benefit of horizontal fusion comes from reducing kernel
// launch overhead, we want to exclude instructions whose launch overhead is
// insignificant relative to their computation. In other words, we exclude
// instructions whose computation latencies are longer than their launch
// latencies. We estimate the computation latency of an instruction from its
// shapes and the instruction count in its fused computation. We roughly
// observe that a fusion instruction with shapes smaller than `kShapeThreshold`
// and fewer instructions than `kInstrCountThreshold` is launch-latency-bound
// and profitable to fuse horizontally.
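// For example, a kLoop fusion whose outputs are f32[128,1024] and whose fused
// computation contains 10 instructions falls under both thresholds and is
// kept as a candidate.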
bool IsProfitableFusionCandidate(const HloInstruction& instr) {
  constexpr int64_t kShapeThreshold = 128 * 2048;
  constexpr int64_t kInstrCountThreshold = 30;
  const HloInstruction* root = (instr.opcode() == HloOpcode::kFusion)
                                   ? instr.fused_expression_root()
                                   : &instr;

  // Instructions with too-large shapes are unlikely to be profitable.
  if (root->opcode() == HloOpcode::kTuple) {
    // Since all output shapes are the same, use the first shape as the
    // representative.
    Shape shape = root->operand(0)->shape();
    if (ShapeUtil::ElementsIn(shape) > kShapeThreshold) {
      return false;
    }
  } else {
    Shape shape = root->shape();
    if (ShapeUtil::ElementsIn(shape) > kShapeThreshold) {
      return false;
    }
  }

  // Fusions with too many instructions are unlikely to be profitable.
  if (instr.opcode() == HloOpcode::kFusion &&
      instr.fused_instruction_count() > kInstrCountThreshold) {
    return false;
  }

  return true;
}

// Returns whether `instr` has only row-major layouts.
// Horizontal fusion excludes computations with non-row-major layouts, because
// fusing computations with different layouts can result in uncoalesced memory
// accesses and incur significant performance overhead.
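// For example, a rank-2 shape with minor-to-major order {1, 0} is row-major
// and accepted, whereas {0, 1} (column-major) is rejected.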
bool HasOnlyRowMajorLayout(const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kFusion) {
    return LayoutUtil::IsMonotonicWithDim0Major(instr.shape().layout());
  }

  auto fused_instrs = instr.fused_instructions_computation()->instructions();
  for (HloInstruction* i : fused_instrs) {
    if (!LayoutUtil::IsDenseArray(i->shape())) {
      continue;
    }
    if (!LayoutUtil::IsMonotonicWithDim0Major(i->shape().layout())) {
      return false;
    }
  }
  return true;
}

// Returns whether any operand of `instr` is a parameter instruction that
// is shared with `fusion_instrs`.
bool AnyOpndIsParamSharedAmongFusions(
    const HloInstruction* instr,
    const absl::flat_hash_set<HloInstruction*>& fusion_instrs) {
  return absl::c_any_of(instr->operands(), [&](const HloInstruction* opnd) {
    return opnd->opcode() == HloOpcode::kParameter &&
           absl::c_any_of(opnd->users(), [&](const HloInstruction* user) {
             return user != instr && fusion_instrs.contains(user);
           });
  });
}

void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
    HloInstruction* consumer) {
  // First, find all potential target candidates. We will filter out
  // unsupported/non-profitable cases below.
  absl::flat_hash_set<HloInstruction*> fusible_candidates;
  std::vector<HloInstruction*> ordered_fusible_candidates;
  for (HloInstruction* opnd : consumer->operands()) {
    HloInstruction* predecessor = opnd->LatestNonGteAncestor();
    // We support kLoop fusions and element-wise HLOs now. We may extend the
    // support list if the need arises.
    if (IsFusibleCandidate(*predecessor)) {
      if (fusible_candidates.insert(predecessor).second) {
        // Add unseen fusion to ordered list.
        ordered_fusible_candidates.push_back(predecessor);
      }
    }
  }

  for (HloInstruction* instr : ordered_fusible_candidates) {
    if (!IsConsumerTheOnlyNonRootUser(*instr, *consumer)) {
      VLOG(2) << "Reject potentially illegal instr " << instr->ToString()
              << "; including it may create cycles in HLO.";
      continue;
    } else if (!IsProfitableFusionCandidate(*instr)) {
      VLOG(2) << "Reject possibly unprofitable fusion instr "
              << instr->ToString();
      continue;
    } else if (!HasOnlyRowMajorLayout(*instr)) {
      VLOG(2) << "Reject non-row-major fusion instr " << instr->ToString();
      continue;
    } else if (AnyOpndIsParamSharedAmongFusions(instr, fusible_candidates)) {
      // Don't fuse fusions whose operands are parameter instructions that are
      // shared among fusions, because we cannot i/o alias the produced
      // horizontal fusion due to the concat insertion.
      VLOG(2) << "Reject the fusion instr because it shares a parameter with"
              << " other fusion candidates, instr: " << instr->ToString();
      continue;
    } else {
      VLOG(2) << "Found a fusion candidate " << instr->ToString();
      // Encapsulate it into a fusion computation for unified representation
      // for later processing.
      fusible_instrs_.push_back(instr);
    }
  }

  // Sort `fusible_instrs_` according to output types, the number of outputs,
  // and instruction counts, because we only fuse instructions with the same
  // number/type of outputs and whose computations have the same instruction
  // count.
  std::sort(
      fusible_instrs_.begin(), fusible_instrs_.end(),
      [&](const HloInstruction* a, const HloInstruction* b) {
        if (GetUniqueOutputTypeOfFusible(*a) !=
            GetUniqueOutputTypeOfFusible(*b)) {
          return GetUniqueOutputTypeOfFusible(*a) <
                 GetUniqueOutputTypeOfFusible(*b);
        } else if (GetOutputSizeOfFusible(*a) != GetOutputSizeOfFusible(*b)) {
          return GetOutputSizeOfFusible(*a) < GetOutputSizeOfFusible(*b);
        } else {
          return GetInstrCountOfFusible(*a) < GetInstrCountOfFusible(*b);
        }
      });
}

// Gets the next span of fusion instructions to be fused.
absl::Span<HloInstruction*>
HorizontalLoopFusionImpl::FusionCandidates::GetNextSpanOfFusions() {
  if (pos_ >= fusible_instrs_.size()) {
    return absl::Span<HloInstruction*>();
  }

  // Fusing too many computations at a time may hurt profitability and may
  // increase compile time due to large kernels. Set a limit on the batch size.
  constexpr int64_t kMaxFusionBatchSize = 32;
  // CUDA has a parameter size limit of ~4k bytes.
  constexpr int64_t kMaxCudaParamSize = 4000;
  size_t accum_io_size = 0;
  auto reach_max_fusion_batch_size = [&](size_t left, size_t right) -> bool {
    if (right - left >= kMaxFusionBatchSize) {
      return true;
    }

    accum_io_size += fusible_instrs_.at(right)->operand_count() +
                     GetOutputSizeOfFusible(*fusible_instrs_.at(right));

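    // Each operand/output is passed to the fused kernel as a pointer argument
    // (assumed to be 8 bytes here), so bound the accumulated parameter bytes.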
    if (accum_io_size * 8 >= kMaxCudaParamSize) {
      return true;
    }

    return false;
  };

  size_t left = pos_;
  size_t right = pos_ + 1;
  size_t first_output_size = GetOutputSizeOfFusible(*fusible_instrs_[left]);
  PrimitiveType first_output_type =
      GetUniqueOutputTypeOfFusible(*fusible_instrs_[left]);
  for (; right < fusible_instrs_.size(); ++right) {
    PrimitiveType cur_output_type =
        GetUniqueOutputTypeOfFusible(*fusible_instrs_[right]);
    if (first_output_type != cur_output_type) {
      // Cannot fuse computations with different output types.
      break;
    } else if (first_output_size !=
               GetOutputSizeOfFusible(*fusible_instrs_[right])) {
      // Cannot fuse computations with different numbers of outputs.
      break;
    } else if (GetInstrCountOfFusible(*fusible_instrs_[left]) !=
               GetInstrCountOfFusible(*fusible_instrs_[right])) {
      // Do not fuse computations with different instruction counts, as doing
      // so may introduce control divergence. This is a very simple heuristic
      // to avoid fusing computations with too much discrepancy, and we may
      // improve it when the need arises.
      break;
    } else if (reach_max_fusion_batch_size(left, right)) {
      // Hit max fusion batch size.
      break;
    }
  }

  pos_ = right;
  return absl::MakeSpan(fusible_instrs_).subspan(left, right - left);
}

Status HorizontalLoopFusionImpl::CreateFusedComputation(
    absl::Span<HloInstruction*> fused_fusion_instrs,
    std::unique_ptr<HloComputation>* uniq_computation,
    std::vector<HloInstruction*>* bound_operands) {
  // First, build a computation with only params.
  HloComputation::Builder b(prefix_ + "horizontally_fused_computation");
  size_t fused_comp_param_id = 0;
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    auto old_params = fused_fusion_instrs[i]->fused_parameters();
    for (size_t j = 0; j < old_params.size(); ++j) {
      HloInstruction* bound_opnd = fused_fusion_instrs[i]->mutable_operand(j);
      // Name the new parameter in the form of param_i_j.
      b.AddInstruction(HloInstruction::CreateParameter(
          fused_comp_param_id++, bound_opnd->shape(),
          absl::StrCat("param_", i, "_", j)));
      bound_operands->push_back(bound_opnd);
    }
  }
  // Always create a dummy tuple instruction to serve as the root of the
  // computation, as the existence of a root instruction is required by
  // HloComputation. The real root instruction will replace it below.
  HloInstruction* dummy_root = b.AddInstruction(
      HloInstruction::CreateTuple(std::vector<HloInstruction*>{}));
  *uniq_computation = b.Build(dummy_root);
  HloComputation* comp = uniq_computation->get();

  // Prepare clone_map, which maps a parameter of an old fused computation to
  // its counterpart in the new computation.
  absl::flat_hash_map<const HloInstruction*, HloInstruction*> clone_map;
  size_t new_param_id = 0;
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    auto old_params = fused_fusion_instrs[i]->fused_parameters();
    for (size_t j = 0; j < old_params.size(); ++j) {
      HloInstruction* old_param = old_params[j];
      HloInstruction* new_param = comp->parameter_instruction(new_param_id++);
      clone_map.insert({old_param, new_param});
    }
  }

  // Clone every fused computation.
  const OpMetadata* metadata = nullptr;
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    auto def_to_use_order = fused_fusion_instrs[i]
                                ->fused_instructions_computation()
                                ->MakeInstructionPostOrder();
    for (HloInstruction* old_instr : def_to_use_order) {
      if (old_instr->opcode() == HloOpcode::kParameter) {
        // Parameters have been created.
        continue;
      }
      std::vector<HloInstruction*> new_opnds;
      const auto& old_opnds = old_instr->operands();
      new_opnds.reserve(old_opnds.size());
      for (HloInstruction* old_opnd : old_opnds) {
        CHECK(clone_map.find(old_opnd) != clone_map.end());
        new_opnds.push_back(clone_map[old_opnd]);
      }
      HloInstruction* new_instr = comp->AddInstruction(
          old_instr->CloneWithNewOperands(old_instr->shape(), new_opnds));
      clone_map.insert({old_instr, new_instr});
      // Get the metadata from the last fused instruction.
      metadata = &old_instr->metadata();
    }
  }

  std::vector<HloInstruction*> concated_outputs;
  // Since we require each fusion to have the same number of outputs, we can
  // simply use the first fusion as the representative for output size.
  size_t fused_instr_output_size =
      GetOutputSizeOfFusible(*fused_fusion_instrs[0]);
  for (size_t i = 0; i < fused_instr_output_size; ++i) {
    std::vector<HloInstruction*> reshapes(fused_fusion_instrs.size());
    for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
      const HloInstruction* old_output =
          GetOutputsOfFusible(*fused_fusion_instrs[j])[i];
      HloInstruction* new_output = clone_map[old_output];
      TF_ASSIGN_OR_RETURN(
          reshapes[j],
          MakeReshapeHlo(ShapeUtil::MakeShapeWithLayout(
                             new_output->shape().element_type(),
                             {ShapeUtil::ElementsIn(new_output->shape())},
                             /*minor_to_major=*/std::vector<int64_t>(1, 0)),
                         new_output));
    }
    TF_ASSIGN_OR_RETURN(HloInstruction * concated_output,
                        MakeConcatHlo(reshapes, 0));
    concated_outputs.push_back(concated_output);
  }
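  // At this point, each concated_outputs[i] is a rank-1 array whose element
  // count is the sum of the element counts of output i across all fused
  // fusions.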

  // Make slices of outputs.
  std::vector<HloInstruction*> output_slices(concated_outputs.size() *
                                             fused_fusion_instrs.size());
  for (size_t i = 0; i < concated_outputs.size(); ++i) {
    HloInstruction* concated_output = concated_outputs[i];
    int64_t slice_start = 0;
    // Create a slice per fused computation.
    for (size_t j = 0; j < fused_fusion_instrs.size(); ++j) {
      const HloInstruction* old_output =
          GetOutputsOfFusible(*fused_fusion_instrs[j])[i];
      Shape shape = old_output->shape();
      int64_t slice_limit = slice_start + ShapeUtil::ElementsIn(shape);
      TF_ASSIGN_OR_RETURN(
          output_slices[concated_outputs.size() * j + i],
          MakeSliceHlo(concated_output, {slice_start}, {slice_limit},
                       /*strides=*/{1}));
      slice_start = slice_limit;
    }
  }

  // Make a tuple of output_slices.
  HloInstruction* tuple = comp->AddInstruction(
      HloInstruction::CreateTuple(output_slices), metadata);
  comp->set_root_instruction(tuple, /*accept_different_shape=*/true);
  TF_RETURN_IF_ERROR(comp->RemoveInstruction(dummy_root));

  return OkStatus();
}

Status HorizontalLoopFusionImpl::Fuse(
    absl::Span<HloInstruction*> fused_fusion_instrs) {
  // Fuse fused_fusion_instrs and replace them with the new fused computation.
  std::unique_ptr<HloComputation> uniq_computation;
  std::vector<HloInstruction*> bound_operands;
  TF_RETURN_IF_ERROR(CreateFusedComputation(
      fused_fusion_instrs, &uniq_computation, &bound_operands));
  HloComputation* fused_comp = computation_->parent()->AddEmbeddedComputation(
      std::move(uniq_computation));
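  // The horizontal fusion is emitted as a kInput fusion: its root is a tuple
  // of slices of concatenations, i.e., the slice input fusion pattern.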
  HloInstruction* hori_fusion_instr = computation_->AddInstruction(
      HloInstruction::CreateFusion(fused_comp->root_instruction()->shape(),
                                   HloInstruction::FusionKind::kInput,
                                   bound_operands, fused_comp, prefix_),
      &fused_comp->root_instruction()->metadata());
  fused_comp->SetFusionInstruction(hori_fusion_instr);

  // Insert bitcasts and replace corresponding users. Note that we do not
  // insert the bitcasts in the fused computation as it does not fit into the
  // slice input fusion pattern. However, inserting bitcasts outside the fused
  // computation creates no performance cost.
  size_t total_output_id = 0;
  for (size_t i = 0; i < fused_fusion_instrs.size(); ++i) {
    std::vector<HloInstruction*> bitcasts;
    HloInstruction* fused_instr = fused_fusion_instrs[i];
    size_t num_outputs = GetOutputSizeOfFusible(*fused_instr);
    for (size_t j = 0; j < num_outputs; ++j) {
      const HloInstruction* output = GetOutputsOfFusible(*fused_instr)[j];
      TF_ASSIGN_OR_RETURN(
          HloInstruction * gep,
          MakeGetTupleElementHlo(hori_fusion_instr, total_output_id++));
      bitcasts.push_back(computation_->AddInstruction(
          HloInstruction::CreateBitcast(output->shape(), gep)));
    }
    HloInstruction* bitcast_or_tuple =
        (bitcasts.size() == 1) ? bitcasts.at(0)
                               : computation_->AddInstruction(
                                     HloInstruction::CreateTuple(bitcasts));
    TF_RETURN_IF_ERROR(
        computation_->ReplaceInstruction(fused_instr, bitcast_or_tuple));
  }

  return OkStatus();
}

StatusOr<bool> HorizontalLoopFusionImpl::Run() {
  bool changed = false;
  XLA_VLOG_LINES(3, computation_->ToString());

  // Traverse from use to def. Bitcasts are placed after h-fusions to resolve
  // shape mismatches, but those bitcasts could prevent future h-fusions from
  // happening. So a bottom-up, use-to-def order should be more favorable. It
  // also helps save the compiler iterations needed to reach the fixed point.
  std::vector<HloInstruction*> use_to_def_order =
      computation_->MakeInstructionPostOrder();
  absl::c_reverse(use_to_def_order);
  for (size_t i = 0; i < use_to_def_order.size(); ++i) {
    HloInstruction* consumer = use_to_def_order[i];
    HorizontalLoopFusionImpl::FusionCandidates fusion_candidates(consumer);
    while (true) {
      auto fusibles = fusion_candidates.GetNextSpanOfFusions();
      if (fusibles.empty()) {
        break;
      } else if (fusibles.size() == 1) {
        // Skip; there is just one fused_instr.
        continue;
      }

      changed = true;
      // Wrap each fusible in a kLoop fusion instruction (if it is not one
      // already) to simplify the implementation of `Fuse()`.
      std::vector<HloInstruction*> fusion_instrs;
      for (HloInstruction* instr : fusibles) {
        if (instr->opcode() == HloOpcode::kFusion) {
          fusion_instrs.push_back(instr);
        } else {
          TF_ASSIGN_OR_RETURN(
              HloInstruction * fusion_instr,
              MakeFusionInstruction(instr, HloInstruction::FusionKind::kLoop));
          fusion_instrs.push_back(fusion_instr);
        }
      }
      TF_RETURN_IF_ERROR(Fuse(absl::MakeSpan(fusion_instrs)));
    }
  }

  return changed;
}

}  // namespace

StatusOr<bool> GpuHorizontalLoopFusion::RunOnComputation(
    HloComputation* computation) {
  HorizontalLoopFusionImpl horizontal_fusion_impl(computation, prefix_);
  return horizontal_fusion_impl.Run();
}

StatusOr<bool> GpuHorizontalLoopFusion::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  bool changed = false;
  VLOG(2) << "Run horizontal fusion.";

  // Running on the entry computation is enough.
  TF_ASSIGN_OR_RETURN(changed, RunOnComputation(module->entry_computation()));

  return changed;
}

}  // namespace gpu
}  // namespace xla