• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
17 
18 #include "absl/container/flat_hash_set.h"
19 #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
20 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
23 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
24 #include "tensorflow/compiler/xla/shape_util.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26 
27 namespace xla {
28 namespace gpu {
29 
30 namespace {
31 
IsIEEEFloatingPointScalarConstant(const HloInstruction * constant)32 bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) {
33   if (constant->opcode() != HloOpcode::kConstant ||
34       !ShapeUtil::IsScalar(constant->shape())) {
35     return false;
36   }
37   auto type = constant->shape().element_type();
38   return type == F16 || type == F32 || type == F64;
39 }
40 
41 }  // namespace
42 
IsExpensive(const HloInstruction & instruction)43 /*static*/ bool GpuInstructionFusion::IsExpensive(
44     const HloInstruction& instruction) {
45   switch (instruction.opcode()) {
46     // We say that floating-point division is cheap on the GPU.
47     case HloOpcode::kDivide:
48       return !ShapeUtil::ElementIsFloating(instruction.shape()) &&
49              InstructionFusion::IsExpensive(instruction);
50 
51     default:
52       return InstructionFusion::IsExpensive(instruction);
53   }
54 }
55 
56 // This function limits the maximum number of operands to a fusion.
57 //
58 // There's a cap on how many parameters we can pass to a CUDA kernel, but
59 // exactly what that limit is hazy, as it depends on (among other things) how
60 // much GPU constant memory is in use for other purposes.
61 //
62 // Moreover, we don't even know at the point that we're running fusion how many
63 // arguments the CUDA kernel for a fusion node will have: It depends on buffer
64 // assignment, where we will decide which of the fusion's operands live in XLA's
65 // big temp buffer versus in other allocations.
66 //
67 // As a heuristic, we simply cap the number of fusion operands plus outputs at
68 // kMaxOperandsAndOutputsPerFusion.  This puts an upper bound on the number of
69 // parameters to the kernel, working around the correctness problem.
70 //
71 // This limit is also often good for performance.  In a fusion with many
72 // operands, each GPU thread likely has to do a lot of work, and so possibly
73 // uses a lot of registers, thus limiting occupancy.
FusionWouldBeTooLarge(const HloInstruction * a,const HloInstruction * b)74 /*static*/ bool GpuInstructionFusion::FusionWouldBeTooLarge(
75     const HloInstruction* a, const HloInstruction* b) {
76   // Compute the number of outputs of the (possibly multi-output) fusion node
77   // we're considering creating.
78   //
79   // This isn't precise; we may be off by one if
80   //  - We're creating a multi-output fusion out of two non-MOFs.  Creating a
81   //    MOF adds a new buffer, namely, the tuple buffer.
82   //  - We're merging two MOFs.  In this case, we should count the tuple buffer
83   //    only once.
84   //  - WLOG there's an edge from `a` to `b` and `b` is the only consumer of
85   //    `a`.  In this case the result of `a` is not part of the output of the
86   //    fusion.
87   //
88   // But because this is a heuristic and our limit
89   // kMaxOperandsAndOutputsPerFusion is a large value (so +/- 1 doesn't make a
90   // big difference), we ignore this small inaccuracy in favor of simplicity.
91   int64 num_output_buffers = ShapeUtil::SubshapeCount(a->shape()) +
92                              ShapeUtil::SubshapeCount(b->shape());
93 
94   // The new fusion will have no more operands and outputs than
95   //   producer_operands + consumer_operands - 1 + num_output_buffers
96   // (minus one because we may be fusing a producer->consumer edge between `a`
97   // and `b`).
98   //
99   // This fact may be enough to let us avoid having to compute the true total
100   // number of operands, which can be expensive.
101   if (a->operand_count() + b->operand_count() - 1 + num_output_buffers <=
102       kMaxOperandsAndOutputsPerFusion) {
103     return false;
104   }
105 
106   // Compute the precise number of operands to the new fusion.
107   absl::flat_hash_set<const HloInstruction*> operands(a->operands().begin(),
108                                                       a->operands().end());
109   operands.insert(b->operands().begin(), b->operands().end());
110   // If there's an edge between `a` and `b`, don't count it: We're fusing that
111   // producer -> consumer relationship.
112   operands.erase(a);
113   operands.erase(b);
114   return operands.size() + num_output_buffers > kMaxOperandsAndOutputsPerFusion;
115 }
116 
ShouldFuseInexpensiveChecks(HloInstruction * consumer,int64 operand_index)117 bool GpuInstructionFusion::ShouldFuseInexpensiveChecks(HloInstruction* consumer,
118                                                        int64 operand_index) {
119   HloInstruction* producer = consumer->mutable_operand(operand_index);
120 
121   // Check if we can use output fusion for (A @ B) * alpha
122   if (producer->opcode() == HloOpcode::kDot ||
123       (producer->opcode() == HloOpcode::kFusion &&
124        producer->fused_expression_root()->opcode() == HloOpcode::kDot)) {
125     int64 other_operand_index = 1 - operand_index;
126     HloInstruction* op1 = nullptr;
127     HloInstruction* op2 = nullptr;
128     if (consumer->operand_count() == 1 &&
129         consumer->opcode() == HloOpcode::kFusion &&
130         consumer->fusion_kind() == HloInstruction::FusionKind::kLoop &&
131         Match(consumer->fused_expression_root(),
132               match::Op()
133                   .WithOpcode(HloOpcode::kMultiply)
134                   .WithOperand(0, match::Op(&op1))
135                   .WithOperand(1, match::Op(&op2)))) {
136       CHECK(op1 != nullptr && op2 != nullptr);
137       // If 'consumer' is a fusion node, it should consist of a broadcast of a
138       // scalar constant fused into a multiply, but nothing more. So one operand
139       // should be a parameter, and the other should be a broadcast.
140       if (op1->opcode() != HloOpcode::kParameter) {
141         std::swap(op1, op2);
142       }
143       if (op1->opcode() != HloOpcode::kParameter ||
144           op2->opcode() != HloOpcode::kBroadcast) {
145         return false;
146       }
147       if (IsIEEEFloatingPointScalarConstant(op2->operand(0))) {
148         return true;
149       }
150     } else if (consumer->operand_count() == 2 &&
151                consumer->opcode() == HloOpcode::kMultiply) {
152       const HloInstruction* alpha = consumer->operand(other_operand_index);
153       // Fuse if 'alpha' is a broadcast of a scalar constant.
154       if (alpha->opcode() == HloOpcode::kBroadcast &&
155           alpha->dimensions().empty() &&
156           IsIEEEFloatingPointScalarConstant(alpha->operand(0))) {
157         return true;
158       }
159     } else if (consumer->operand_count() == 2 &&
160                consumer->opcode() == HloOpcode::kAdd &&
161                consumer->operand(other_operand_index) != producer) {
162       // Fuse a bias add into the output of the dot.
163       return true;
164     }
165   }
166 
167   // Only allow fusing transpose or broadcast into an output fusion that is
168   // implemented as a Gemm call.
169   if (consumer->opcode() == HloOpcode::kFusion &&
170       consumer->fusion_kind() == HloInstruction::FusionKind::kOutput &&
171       ImplementedAsGemm(*consumer)) {
172     auto producer_operand_index = consumer->operand_index(producer);
173     auto fused_parameter = consumer->fused_parameter(producer_operand_index);
174     const std::vector<HloInstruction*>& fused_parameter_users =
175         fused_parameter->users();
176     if (fused_parameter_users.size() != 1) {
177       return false;
178     }
179     if (producer->opcode() == HloOpcode::kTranspose) {
180       // Check that the transpose is an operand of a dot.
181       return fused_parameter_users[0]->opcode() == HloOpcode::kDot;
182     }
183     if (producer->opcode() == HloOpcode::kBroadcast) {
184       // Check that the broadcast is a broadcast of a scalar constant into a
185       // multiply.
186       return producer->dimensions().empty() &&
187              IsIEEEFloatingPointScalarConstant(producer->operand(0)) &&
188              fused_parameter_users[0]->opcode() == HloOpcode::kMultiply;
189     }
190     return false;
191   }
192 
193   // Other output fusions are not currently supported on GPUs.
194   if (producer->opcode() == HloOpcode::kFusion) {
195     return false;
196   }
197 
198   // RNG operations are not currently parallel-friendly on GPU.
199   if (producer->opcode() == HloOpcode::kRng) {
200     return false;
201   }
202 
203   // Do not fuse to-vector reduction into other consumers. They should be
204   // unfused or the root of a kInput fusion.
205   if (IsReductionToVector(*producer)) {
206     return false;
207   }
208 
209   // Scatter is only supported at the root of a kInput fusion.
210   if (producer->opcode() == HloOpcode::kScatter) {
211     return false;
212   }
213 
214   // Do not fuse into reduce input fusions if the resulting kernel would suffer
215   // from poor data locality (due to unfriendly input layouts).
216   if (IsInputFusibleReduction(*consumer) &&
217       !LayoutsAreReduceInputFusionFriendly(*producer, *consumer)) {
218     return false;
219   }
220 
221   // We can't fuse library calls, so if a user of such an op could become a
222   // bitcast, leave it unfused. See `xla::InstructionFusion::ShouldFuse` for
223   // further rationale.
224   if (producer->CouldBeBitcast() &&
225       ImplementedAsLibraryCall(*producer->operand(0))) {
226     return false;
227   }
228 
229   // Cost condition: not fuse (simple, expensive producers) and (consumers who
230   // reuse operand elements).
231   if (producer->opcode() != HloOpcode::kFusion &&
232       consumer->ReusesOperandElements(operand_index) &&
233       is_expensive(*producer)) {
234     return false;
235   }
236 
237   // Fuse scalar constants into loop fusion nodes. This reduces the number of
238   // parameters and makes matching scalar broadcasts easier.
239   //
240   // Don't fuse other constants: Unfused constants in GPU land can be
241   // represented as an external constant (i.e. not emitted in LLVM IR / PTX),
242   // but fused constants are handled by shrared CPU/GPU code and always emitted
243   // in the IR/PTX.  The external constant representation makes for faster
244   // compiles and significantly smaller assembly code.
245   if (producer->opcode() == HloOpcode::kConstant) {
246     return ShapeUtil::IsEffectiveScalar(producer->shape()) &&
247            consumer->opcode() == HloOpcode::kFusion;
248   }
249 
250   if (!IsFusible(*producer) || !IsFusible(*consumer) ||
251       !InstructionFusion::ShouldFuse(consumer, operand_index)) {
252     return false;
253   }
254   return true;
255 }
256 
ShouldFuse(HloInstruction * consumer,int64 operand_index)257 bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
258                                       int64 operand_index) {
259   if (!ShouldFuseInexpensiveChecks(consumer, operand_index)) {
260     return false;
261   }
262   auto producer = consumer->operand(operand_index);
263   // The following checks are potentially expensive.
264   if (FusionWouldBeTooLarge(consumer, producer)) {
265     return false;
266   }
267   // Also check that our emitter can handle the fusion node. We currently can
268   // have exponential time/memory requirements for emitting certain fusion
269   // kernels, in which case we don't want to fuse.
270   // TODO(b/119692968): Remove this once we have fixed our fusion emitter.
271   return !FusedIrEmitter::IsFusedIrEmitterInefficient(consumer, producer);
272 }
273 
ShouldFuseIntoMultiOutput(HloInstruction * consumer,int64 operand_index)274 bool GpuInstructionFusion::ShouldFuseIntoMultiOutput(HloInstruction* consumer,
275                                                      int64 operand_index) {
276   return false;
277 }
278 
// Picks the fusion kind for fusing `producer` into `consumer`:
//  - kInput when the consumer is a to-vector reduction or a scatter;
//  - kOutput when the producer is a dot or a fusion rooted at a dot;
//  - the consumer's own kind when the consumer is already a fusion;
//  - otherwise whatever InstructionFusion::ChooseKind selects.
// The rules are tried in that order.
HloInstruction::FusionKind GpuInstructionFusion::ChooseKind(
    const HloInstruction* producer, const HloInstruction* consumer) {
  using FusionKind = HloInstruction::FusionKind;

  const bool consumer_wants_input_fusion =
      IsReductionToVector(*consumer) ||
      consumer->opcode() == HloOpcode::kScatter;
  if (consumer_wants_input_fusion) {
    return FusionKind::kInput;
  }

  const bool producer_is_dot =
      producer->opcode() == HloOpcode::kDot ||
      (producer->opcode() == HloOpcode::kFusion &&
       producer->fused_expression_root()->opcode() == HloOpcode::kDot);
  if (producer_is_dot) {
    return FusionKind::kOutput;
  }

  return consumer->opcode() == HloOpcode::kFusion
             ? consumer->fusion_kind()
             : InstructionFusion::ChooseKind(producer, consumer);
}
295 
296 }  // namespace gpu
297 }  // namespace xla
298