1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
17
18 #include "absl/container/flat_hash_set.h"
19 #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
20 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
23 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
24 #include "tensorflow/compiler/xla/shape_util.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26
27 namespace xla {
28 namespace gpu {
29
30 namespace {
31
IsIEEEFloatingPointScalarConstant(const HloInstruction * constant)32 bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) {
33 if (constant->opcode() != HloOpcode::kConstant ||
34 !ShapeUtil::IsScalar(constant->shape())) {
35 return false;
36 }
37 auto type = constant->shape().element_type();
38 return type == F16 || type == F32 || type == F64;
39 }
40
41 } // namespace
42
IsExpensive(const HloInstruction & instruction)43 /*static*/ bool GpuInstructionFusion::IsExpensive(
44 const HloInstruction& instruction) {
45 switch (instruction.opcode()) {
46 // We say that floating-point division is cheap on the GPU.
47 case HloOpcode::kDivide:
48 return !ShapeUtil::ElementIsFloating(instruction.shape()) &&
49 InstructionFusion::IsExpensive(instruction);
50
51 default:
52 return InstructionFusion::IsExpensive(instruction);
53 }
54 }
55
56 // This function limits the maximum number of operands to a fusion.
57 //
58 // There's a cap on how many parameters we can pass to a CUDA kernel, but
59 // exactly what that limit is hazy, as it depends on (among other things) how
60 // much GPU constant memory is in use for other purposes.
61 //
62 // Moreover, we don't even know at the point that we're running fusion how many
63 // arguments the CUDA kernel for a fusion node will have: It depends on buffer
64 // assignment, where we will decide which of the fusion's operands live in XLA's
65 // big temp buffer versus in other allocations.
66 //
67 // As a heuristic, we simply cap the number of fusion operands plus outputs at
68 // kMaxOperandsAndOutputsPerFusion. This puts an upper bound on the number of
69 // parameters to the kernel, working around the correctness problem.
70 //
71 // This limit is also often good for performance. In a fusion with many
72 // operands, each GPU thread likely has to do a lot of work, and so possibly
73 // uses a lot of registers, thus limiting occupancy.
FusionWouldBeTooLarge(const HloInstruction * a,const HloInstruction * b)74 /*static*/ bool GpuInstructionFusion::FusionWouldBeTooLarge(
75 const HloInstruction* a, const HloInstruction* b) {
76 // Compute the number of outputs of the (possibly multi-output) fusion node
77 // we're considering creating.
78 //
79 // This isn't precise; we may be off by one if
80 // - We're creating a multi-output fusion out of two non-MOFs. Creating a
81 // MOF adds a new buffer, namely, the tuple buffer.
82 // - We're merging two MOFs. In this case, we should count the tuple buffer
83 // only once.
84 // - WLOG there's an edge from `a` to `b` and `b` is the only consumer of
85 // `a`. In this case the result of `a` is not part of the output of the
86 // fusion.
87 //
88 // But because this is a heuristic and our limit
89 // kMaxOperandsAndOutputsPerFusion is a large value (so +/- 1 doesn't make a
90 // big difference), we ignore this small inaccuracy in favor of simplicity.
91 int64 num_output_buffers = ShapeUtil::SubshapeCount(a->shape()) +
92 ShapeUtil::SubshapeCount(b->shape());
93
94 // The new fusion will have no more operands and outputs than
95 // producer_operands + consumer_operands - 1 + num_output_buffers
96 // (minus one because we may be fusing a producer->consumer edge between `a`
97 // and `b`).
98 //
99 // This fact may be enough to let us avoid having to compute the true total
100 // number of operands, which can be expensive.
101 if (a->operand_count() + b->operand_count() - 1 + num_output_buffers <=
102 kMaxOperandsAndOutputsPerFusion) {
103 return false;
104 }
105
106 // Compute the precise number of operands to the new fusion.
107 absl::flat_hash_set<const HloInstruction*> operands(a->operands().begin(),
108 a->operands().end());
109 operands.insert(b->operands().begin(), b->operands().end());
110 // If there's an edge between `a` and `b`, don't count it: We're fusing that
111 // producer -> consumer relationship.
112 operands.erase(a);
113 operands.erase(b);
114 return operands.size() + num_output_buffers > kMaxOperandsAndOutputsPerFusion;
115 }
116
// Cheap, purely structural checks that decide whether the operand of
// `consumer` at `operand_index` (the "producer") should be fused into
// `consumer`. Returns true to fuse, false otherwise.
//
// The potentially expensive checks (fusion size, emitter efficiency) are done
// separately in ShouldFuse(). Most checks below return early, so the order of
// the statements is significant.
bool GpuInstructionFusion::ShouldFuseInexpensiveChecks(HloInstruction* consumer,
                                                       int64 operand_index) {
  HloInstruction* producer = consumer->mutable_operand(operand_index);

  // Check if we can use output fusion for (A @ B) * alpha (or a bias add),
  // where the producer is a dot or a fusion rooted at a dot.
  if (producer->opcode() == HloOpcode::kDot ||
      (producer->opcode() == HloOpcode::kFusion &&
       producer->fused_expression_root()->opcode() == HloOpcode::kDot)) {
    // Index of the consumer's other operand; only used by the two-operand
    // (multiply / add) cases below.
    int64 other_operand_index = 1 - operand_index;
    HloInstruction* op1 = nullptr;
    HloInstruction* op2 = nullptr;
    if (consumer->operand_count() == 1 &&
        consumer->opcode() == HloOpcode::kFusion &&
        consumer->fusion_kind() == HloInstruction::FusionKind::kLoop &&
        Match(consumer->fused_expression_root(),
              match::Op()
                  .WithOpcode(HloOpcode::kMultiply)
                  .WithOperand(0, match::Op(&op1))
                  .WithOperand(1, match::Op(&op2)))) {
      CHECK(op1 != nullptr && op2 != nullptr);
      // If 'consumer' is a fusion node, it should consist of a broadcast of a
      // scalar constant fused into a multiply, but nothing more. So one operand
      // should be a parameter, and the other should be a broadcast.
      if (op1->opcode() != HloOpcode::kParameter) {
        std::swap(op1, op2);
      }
      if (op1->opcode() != HloOpcode::kParameter ||
          op2->opcode() != HloOpcode::kBroadcast) {
        return false;
      }
      if (IsIEEEFloatingPointScalarConstant(op2->operand(0))) {
        return true;
      }
      // Otherwise fall through to the generic checks below.
    } else if (consumer->operand_count() == 2 &&
               consumer->opcode() == HloOpcode::kMultiply) {
      const HloInstruction* alpha = consumer->operand(other_operand_index);
      // Fuse if 'alpha' is a broadcast of a scalar constant.
      if (alpha->opcode() == HloOpcode::kBroadcast &&
          alpha->dimensions().empty() &&
          IsIEEEFloatingPointScalarConstant(alpha->operand(0))) {
        return true;
      }
      // Otherwise fall through to the generic checks below.
    } else if (consumer->operand_count() == 2 &&
               consumer->opcode() == HloOpcode::kAdd &&
               consumer->operand(other_operand_index) != producer) {
      // Fuse a bias add into the output of the dot. This accepts any binary
      // add whose other operand is not the dot itself, and returns without
      // running the generic IsFusible / ShouldFuse checks below.
      return true;
    }
  }

  // Only allow fusing transpose or broadcast into an output fusion that is
  // implemented as a Gemm call. In either case the fused parameter that
  // corresponds to the producer must have exactly one user inside the fusion.
  if (consumer->opcode() == HloOpcode::kFusion &&
      consumer->fusion_kind() == HloInstruction::FusionKind::kOutput &&
      ImplementedAsGemm(*consumer)) {
    auto producer_operand_index = consumer->operand_index(producer);
    auto fused_parameter = consumer->fused_parameter(producer_operand_index);
    const std::vector<HloInstruction*>& fused_parameter_users =
        fused_parameter->users();
    if (fused_parameter_users.size() != 1) {
      return false;
    }
    if (producer->opcode() == HloOpcode::kTranspose) {
      // Check that the transpose is an operand of a dot.
      return fused_parameter_users[0]->opcode() == HloOpcode::kDot;
    }
    if (producer->opcode() == HloOpcode::kBroadcast) {
      // Check that the broadcast is a broadcast of a scalar constant into a
      // multiply.
      return producer->dimensions().empty() &&
             IsIEEEFloatingPointScalarConstant(producer->operand(0)) &&
             fused_parameter_users[0]->opcode() == HloOpcode::kMultiply;
    }
    // Any other producer is rejected for Gemm output fusions.
    return false;
  }

  // Other output fusions are not currently supported on GPUs. Any fusion
  // producer that did not take an early `return true` above is rejected here.
  if (producer->opcode() == HloOpcode::kFusion) {
    return false;
  }

  // RNG operations are not currently parallel-friendly on GPU.
  if (producer->opcode() == HloOpcode::kRng) {
    return false;
  }

  // Do not fuse to-vector reduction into other consumers. They should be
  // unfused or the root of a kInput fusion.
  if (IsReductionToVector(*producer)) {
    return false;
  }

  // Scatter is only supported at the root of a kInput fusion.
  if (producer->opcode() == HloOpcode::kScatter) {
    return false;
  }

  // Do not fuse into reduce input fusions if the resulting kernel would suffer
  // from poor data locality (due to unfriendly input layouts).
  if (IsInputFusibleReduction(*consumer) &&
      !LayoutsAreReduceInputFusionFriendly(*producer, *consumer)) {
    return false;
  }

  // We can't fuse library calls, so if a user of such an op could become a
  // bitcast, leave it unfused. See `xla::InstructionFusion::ShouldFuse` for
  // further rationale.
  if (producer->CouldBeBitcast() &&
      ImplementedAsLibraryCall(*producer->operand(0))) {
    return false;
  }

  // Cost condition: don't fuse a simple (non-fusion) expensive producer into a
  // consumer that reuses its operand elements, since fusion would likely
  // recompute the expensive producer for each reused element. (The kFusion
  // test is always true here: fusion producers were already rejected above.)
  if (producer->opcode() != HloOpcode::kFusion &&
      consumer->ReusesOperandElements(operand_index) &&
      is_expensive(*producer)) {
    return false;
  }

  // Fuse scalar constants into loop fusion nodes. This reduces the number of
  // parameters and makes matching scalar broadcasts easier.
  //
  // Don't fuse other constants: Unfused constants in GPU land can be
  // represented as an external constant (i.e. not emitted in LLVM IR / PTX),
  // but fused constants are handled by shared CPU/GPU code and always emitted
  // in the IR/PTX. The external constant representation makes for faster
  // compiles and significantly smaller assembly code.
  if (producer->opcode() == HloOpcode::kConstant) {
    return ShapeUtil::IsEffectiveScalar(producer->shape()) &&
           consumer->opcode() == HloOpcode::kFusion;
  }

  // Finally, both instructions must be fusible and the generic
  // (target-independent) fusion heuristics must agree.
  if (!IsFusible(*producer) || !IsFusible(*consumer) ||
      !InstructionFusion::ShouldFuse(consumer, operand_index)) {
    return false;
  }
  return true;
}
256
ShouldFuse(HloInstruction * consumer,int64 operand_index)257 bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
258 int64 operand_index) {
259 if (!ShouldFuseInexpensiveChecks(consumer, operand_index)) {
260 return false;
261 }
262 auto producer = consumer->operand(operand_index);
263 // The following checks are potentially expensive.
264 if (FusionWouldBeTooLarge(consumer, producer)) {
265 return false;
266 }
267 // Also check that our emitter can handle the fusion node. We currently can
268 // have exponential time/memory requirements for emitting certain fusion
269 // kernels, in which case we don't want to fuse.
270 // TODO(b/119692968): Remove this once we have fixed our fusion emitter.
271 return !FusedIrEmitter::IsFusedIrEmitterInefficient(consumer, producer);
272 }
273
ShouldFuseIntoMultiOutput(HloInstruction * consumer,int64 operand_index)274 bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer,
275 int64 operand_index) {
276 return false;
277 }
278
// Picks the fusion kind for fusing `producer` into `consumer`:
//  - reductions to a vector and scatters as consumers yield kInput fusions;
//  - a dot producer (or a fusion rooted at a dot) yields a kOutput fusion;
//  - an existing fusion consumer keeps its current kind;
//  - anything else defers to the generic InstructionFusion choice.
HloInstruction::FusionKind GpuInstructionFusion::ChooseKind(
    const HloInstruction* producer, const HloInstruction* consumer) {
  const bool consumer_is_input_fusion_root =
      IsReductionToVector(*consumer) ||
      consumer->opcode() == HloOpcode::kScatter;
  if (consumer_is_input_fusion_root) {
    return HloInstruction::FusionKind::kInput;
  }

  const bool producer_is_dot =
      producer->opcode() == HloOpcode::kDot ||
      (producer->opcode() == HloOpcode::kFusion &&
       producer->fused_expression_root()->opcode() == HloOpcode::kDot);
  if (producer_is_dot) {
    return HloInstruction::FusionKind::kOutput;
  }

  return consumer->opcode() == HloOpcode::kFusion
             ? consumer->fusion_kind()
             : InstructionFusion::ChooseKind(producer, consumer);
}
295
296 } // namespace gpu
297 } // namespace xla
298