• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/horizontal_input_fusion.h"
17 
18 #include <algorithm>
19 
20 #include "absl/container/flat_hash_set.h"
21 #include "absl/strings/str_cat.h"
22 #include "absl/types/span.h"
23 #include "tensorflow/compiler/xla/layout_util.h"
24 #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
25 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
26 #include "tensorflow/compiler/xla/service/hlo_creation_utils.h"
27 #include "tensorflow/core/platform/errors.h"
28 
29 namespace xla {
30 namespace gpu {
31 
32 namespace {
33 
34 // Gets the representative input shape of the multi-output fusion.
GetInputShapeForMultiOutputFusion(const HloInstruction & instr)35 Shape GetInputShapeForMultiOutputFusion(const HloInstruction& instr) {
36   // Get the HLO that determines the emitter used for lowering.
37   const HloInstruction* real_hero = GetRealHeroForMultiOutputFusion(instr);
38   if (real_hero->operands().empty()) {
39     // Simply return an empty shape if the representative node has no input
40     // operands.
41     return Shape();
42   } else {
43     return real_hero->operand(0)->shape();
44   }
45 }
46 
47 class HorizontalInputFusionImpl {
48  public:
HorizontalInputFusionImpl(HloComputation * computation)49   explicit HorizontalInputFusionImpl(HloComputation* computation)
50       : computation_(computation) {}
51 
~HorizontalInputFusionImpl()52   ~HorizontalInputFusionImpl() {}
53 
54   StatusOr<bool> Run();
55 
56  private:
57   HloComputation* computation_;
58 };  // HorizontalInputFusionImpl
59 
60 // Compares one-by-one the dimensions of `shape_a` and `shape_b` from left to
61 // right.
CompareShapeDimsFromLeftToRight(const Shape & shape_a,const Shape & shape_b)62 bool CompareShapeDimsFromLeftToRight(const Shape& shape_a,
63                                      const Shape& shape_b) {
64   if (shape_a.rank() != shape_b.rank()) {
65     return shape_a.rank() < shape_b.rank();
66   }
67   auto dims_a = shape_a.dimensions();
68   auto dims_b = shape_b.dimensions();
69   for (size_t i = 0; i < dims_a.size(); ++i) {
70     if (dims_a[i] != dims_b[i]) {
71       return dims_a[i] < dims_b[i];
72     }
73   }
74   return true;
75 }
76 
FindAndSortFusionCandidates(HloInstruction * consumer)77 std::vector<HloInstruction*> FindAndSortFusionCandidates(
78     HloInstruction* consumer) {
79   absl::flat_hash_set<HloInstruction*> fusion_instr_set;
80   std::vector<HloInstruction*> fusion_instrs;
81   for (HloInstruction* opnd : consumer->operands()) {
82     HloInstruction* predecessor = opnd->LatestNonGteAncestor();
83     // Find out the input fusion instructions whose only consumer is `consumer`.
84     // This guarantees that fusing these candidates will never create cycles, as
85     // there is no back edge.
86     if (IsInputFusibleReduction(*predecessor) &&
87         IsConsumerTheOnlyNonRootUser(*predecessor, *consumer)) {
88       if (fusion_instr_set.insert(predecessor).second) {
89         fusion_instrs.push_back(predecessor);
90       }
91     }
92   }
93 
94   std::sort(fusion_instrs.begin(), fusion_instrs.end(),
95             [&](const HloInstruction* a, const HloInstruction* b) {
96               Shape shape_a = GetInputShapeForMultiOutputFusion(*a);
97               Shape shape_b = GetInputShapeForMultiOutputFusion(*b);
98               if (!ShapeUtil::EqualIgnoringElementType(shape_a, shape_b)) {
99                 // Sort shapes according to dimensions, so that the same input
100                 // shapes will be placed adjacent each other.
101                 return CompareShapeDimsFromLeftToRight(shape_a, shape_b);
102               }
103               // Sort `fusion_instrs` according to instruction counts, because
104               // we'd like to fuse together computations of similar sizes.
105               return GetInstrCountOfFusible(*a) < GetInstrCountOfFusible(*b);
106             });
107 
108   return fusion_instrs;
109 }
110 
Run()111 StatusOr<bool> HorizontalInputFusionImpl::Run() {
112   bool changed = false;
113   XLA_VLOG_LINES(3, computation_->ToString());
114 
115   // Using def-to-use order is sound since we do not modify users.
116   std::vector<HloInstruction*> def_to_use_order =
117       computation_->MakeInstructionPostOrder();
118   for (HloInstruction* consumer : def_to_use_order) {
119     auto candidates = FindAndSortFusionCandidates(consumer);
120     if (candidates.size() <= 1) {
121       continue;
122     }
123 
124     // Convert candidates into fusions if needed.
125     for (size_t j = 0; j < candidates.size(); ++j) {
126       if (candidates[j]->opcode() != HloOpcode::kFusion) {
127         TF_ASSIGN_OR_RETURN(
128             HloInstruction * fusion_instr,
129             MakeFusionInstruction(candidates[j],
130                                   HloInstruction::FusionKind::kInput));
131         candidates[j] = fusion_instr;
132         changed = true;
133       }
134     }
135 
136     size_t fusion_anchor_id = 0;
137     for (size_t j = 1; j < candidates.size(); ++j) {
138       HloInstruction* fusion_anchor = candidates[fusion_anchor_id];
139       HloInstruction* fused = candidates[j];
140       if (ShapesCompatibleForMultiOutputFusion(*fusion_anchor, *fused) &&
141           FusionFitsInBudget(*fusion_anchor, *fused)) {
142         VLOG(3) << "Fuse " << fused->ToString() << " into "
143                 << fusion_anchor->ToString();
144         fusion_anchor->MergeFusionInstructionIntoMultiOutput(fused);
145         changed = true;
146       } else {
147         // Update the `fusion_anchor_id` since `fused` is either not
148         // compatible or not beneficial to be fused with current fusion anchor.
149         VLOG(3) << j - fusion_anchor_id - 1 << " instructions are fused.";
150         fusion_anchor_id = j;
151       }
152     }
153   }
154 
155   return changed;
156 }
157 
158 }  // namespace
159 
RunOnComputation(HloComputation * computation)160 StatusOr<bool> GpuHorizontalInputFusion::RunOnComputation(
161     HloComputation* computation) {
162   HorizontalInputFusionImpl horizontal_fusion_impl(computation);
163   return horizontal_fusion_impl.Run();
164 }
165 
Run(HloModule * module,const absl::flat_hash_set<absl::string_view> & execution_threads)166 StatusOr<bool> GpuHorizontalInputFusion::Run(
167     HloModule* module,
168     const absl::flat_hash_set<absl::string_view>& execution_threads) {
169   bool changed = false;
170   VLOG(2) << "Run horizontal input fusion.";
171   for (HloComputation* comp :
172        module->MakeNonfusionComputations(execution_threads)) {
173     TF_ASSIGN_OR_RETURN(changed, RunOnComputation(comp));
174   }
175 
176   return changed;
177 }
178 
179 }  // namespace gpu
180 }  // namespace xla
181