/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"

#include <algorithm>
#include <iterator>
#include <stack>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"

namespace xla {
namespace gpu {
namespace {

// The amount of shared memory a CUDA kernel can use.
//
// We stay on the conservative side: this is smaller than the full 64kB,
// leaving some extra space for the cache.
constexpr int64 kSharedMemoryBudgetInBytes = 40000;

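// Appends the parameters of `instr` (for fusions) or its operands (for
// everything else) to `params`.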
void AppendParams(const HloInstruction& instr,
                  std::vector<HloInstruction*>* params) {
  if (instr.opcode() == HloOpcode::kFusion) {
    params->insert(std::end(*params), std::begin(instr.fused_parameters()),
                   std::end(instr.fused_parameters()));
  } else {
    for (HloInstruction* operand : instr.operands()) {
      params->push_back(operand);
    }
  }
}

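// Returns true if fusing `instr` into a consumer may cause elements of its
// operands to be read multiple times: a reduce that is not a reduction from
// or to contiguous dimensions, or a reduce-window whose windows overlap.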
bool IfFusedReadsElementsMultipleTimes(const HloInstruction& instr) {
  CHECK_NE(instr.opcode(), HloOpcode::kFusion) << "`instr` has to be unfused.";
  if (instr.opcode() == HloOpcode::kReduce &&
      !IsReductionFromOrToContiguousDimensions(instr)) {
    return true;
  }
  // Avoid fusing reduce-window when the window size exceeds the stride, since
  // overlapping windows cause the same elements to be read multiple times.
  if (instr.opcode() == HloOpcode::kReduceWindow) {
    for (const auto& dim : instr.window().dimensions()) {
      if (dim.size() > dim.stride()) {
        return true;
      }
    }
  }
  return false;
}

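// Returns the relative order of the nontrivial (size > 1) dimensions of
// `shape`, listed from minor to major and renumbered to 0..true_rank-1.
//
// For example, for a shape f32[1,8,1,4] with layout {3,2,1,0}, the nontrivial
// dimensions in minor-to-major order are {3, 1}, which normalize to {1, 0}.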
std::vector<int64> ExtractRelativeOrderOfNontrivialDims(const Shape& shape) {
  std::vector<int64> relative_order;
  for (int64 dim : LayoutUtil::MinorToMajor(shape)) {
    if (shape.dimensions(dim) > 1) {
      relative_order.push_back(dim);
    }
  }
  // Now normalize the dimensions to values between 0 and true rank - 1.
  std::vector<int64> sorted_dims = relative_order;
  std::sort(sorted_dims.begin(), sorted_dims.end());
  for (int64& dim : relative_order) {
    int64 sorted_index = std::distance(
        sorted_dims.begin(),
        std::lower_bound(sorted_dims.begin(), sorted_dims.end(), dim));
    dim = sorted_index;
  }
  return relative_order;
}

}  // namespace

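// Returns true if fusing `producer` into the reduce input fusion `reduce`
// keeps layouts friendly: among all array parameters of maximal true rank,
// the relative order of nontrivial dimensions must agree. Parameters of
// smaller true rank are accepted unconditionally.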
bool LayoutsAreReduceInputFusionFriendly(const HloInstruction& producer,
                                         const HloInstruction& reduce) {
  std::vector<HloInstruction*> params;
  AppendParams(producer, &params);
  AppendParams(reduce, &params);
  int64 max_true_rank = -1;
  std::vector<int64> max_rank_order;
  for (HloInstruction* param : params) {
    if (param->shape().IsArray() &&
        ShapeUtil::TrueRank(param->shape()) > max_true_rank) {
      max_true_rank = ShapeUtil::TrueRank(param->shape());
      max_rank_order = ExtractRelativeOrderOfNontrivialDims(param->shape());
    }
  }
  return absl::c_all_of(params, [&](HloInstruction* param) {
    return !param->shape().IsArray() ||
           ShapeUtil::TrueRank(param->shape()) < max_true_rank ||
           ExtractRelativeOrderOfNontrivialDims(param->shape()) ==
               max_rank_order;
  });
}

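// Returns true if `instr` is a fusion whose root (or, for multi-output
// fusions, any operand of the root tuple) is a reduction from or to
// contiguous dimensions.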
bool IsReduceInputFusion(const HloInstruction& instr) {
  if (instr.IsMultiOutputFusion()) {
    for (const HloInstruction* operand :
         instr.fused_expression_root()->operands()) {
      if (IsReductionFromOrToContiguousDimensions(*operand)) {
        CHECK(instr.IsInputFusion())
            << " Multi-output fusion rooted at reduction-to-vector ops must "
               "be of kind kInput: "
            << instr.ToString();
        return true;
      }
    }
  } else if (instr.opcode() == HloOpcode::kFusion &&
             IsReductionFromOrToContiguousDimensions(
                 *instr.fused_expression_root())) {
    CHECK(instr.IsInputFusion())
        << " Fusion rooted at reduction-to-vector op must be of kind kInput: "
        << instr.ToString();
    return true;
  }
  return false;
}

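// Returns true if `instr` is either a reduce input fusion or an unfused
// reduction from or to contiguous dimensions. Variadic (tuple-shaped)
// reductions are excluded for now (b/129089333).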
bool IsInputFusibleReduction(const HloInstruction& instr) {
  // TODO(b/129089333): Don't fuse variadic reduce.
  if (instr.opcode() == HloOpcode::kReduce && instr.shape().IsTuple()) {
    return false;
  }

  return IsReduceInputFusion(instr) ||
         IsReductionFromOrToContiguousDimensions(instr);
}

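// Returns the "hero" instruction that determines the emission strategy for a
// fusion: the fusion root for single-output fusions, or (preferably) a
// reduction operand of the root tuple for multi-output fusions. For unfused
// instructions, returns the instruction itself.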
const HloInstruction* GetRealHeroForMultiOutputFusion(
    const HloInstruction& instr) {
  if (instr.opcode() != HloOpcode::kFusion) {
    return &instr;
  }
  auto fused_expression_root = instr.fused_expression_root();
  if (!instr.IsMultiOutputFusion()) {
    return fused_expression_root;
  }
  // If possible, we want to pick a reduction-from-or-to-contiguous-dims
  // operand of the fusion root, because it has the most constraints.
  for (const auto* inst : fused_expression_root->operands()) {
    if (IsReductionFromOrToContiguousDimensions(*inst)) {
      return inst;
    }
  }
  return fused_expression_root->operands()[0];
}

bool ShapesCompatibleForMultiOutputFusion(const HloInstruction& instr1,
                                          const HloInstruction& instr2) {
  // Multi-output fusion kernels share a common parallel loop. The loop
  // dimensions are determined by instruction shapes.
  auto get_loop_shape = [&](const HloInstruction* element_instr) {
    // Special-case reduction-to-vector ops: The loop dimensions are
    // determined by the shape of the first operand.
    if (IsReductionFromOrToContiguousDimensions(*element_instr)) {
      return element_instr->operand(0)->shape();
    }
    return element_instr->shape();
  };

  // All shapes of the root tuple of multi-output fusions should agree, i.e.
  // all root ops should have equal output shapes. The exception is
  // reduction-to-vector ops, for which the input shape of the reduction (the
  // first operand's shape) and the reduction dimensions must match.
  auto* instr_1 = GetRealHeroForMultiOutputFusion(instr1);
  auto* instr_2 = GetRealHeroForMultiOutputFusion(instr2);
  if (IsReductionFromOrToContiguousDimensions(*instr_1) &&
      IsReductionFromOrToContiguousDimensions(*instr_2) &&
      !AreFusedReductionOutputsConsistent({instr_1, instr_2}, instr_1)) {
    return false;
  }
  // The elementwise output shapes must be the same (including layout).
  return ShapeUtil::EqualIgnoringElementType(get_loop_shape(instr_1),
                                             get_loop_shape(instr_2));
}

bool IsInputFusibleScatter(const HloInstruction& instr) {
  return instr.opcode() == HloOpcode::kScatter ||
         (instr.opcode() == HloOpcode::kFusion &&
          instr.fusion_kind() == HloInstruction::FusionKind::kInput &&
          instr.fused_expression_root()->opcode() == HloOpcode::kScatter);
}

bool IsInputFusible(const HloInstruction& instr) {
  // Input fusion only handles non-elementwise reduction and scatter
  // operations.
  return instr.IsFusible() &&
         (IsInputFusibleReduction(instr) || IsInputFusibleScatter(instr));
}

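// Returns true if `instr` may take part in a loop fusion, i.e. it is an
// elementwise op or one of a fixed set of shape-manipulating ops whose output
// can be computed element by element.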
bool IsLoopFusible(const HloInstruction& instr) {
  // Don't fuse get-tuple-element on GPU: We can, but it's slower than not
  // fusing. We never generate kernels for unfused GTEs. Instead, if an
  // unfused GTE is an input to a kernel (including a fusion kernel), we
  // compute the address of the GTE at the top of the kernel. Often we know
  // the address of the GTE result statically, so we can do this without
  // chasing any pointers.
  return instr.IsFusible() &&
         ((instr.IsElementwise() && instr.operand_count() > 0) ||
          instr.opcode() == HloOpcode::kBitcast ||
          instr.opcode() == HloOpcode::kBroadcast ||
          instr.opcode() == HloOpcode::kConcatenate ||
          instr.opcode() == HloOpcode::kDynamicSlice ||
          instr.opcode() == HloOpcode::kDynamicUpdateSlice ||
          (instr.opcode() == HloOpcode::kFusion &&
           instr.fusion_kind() == HloInstruction::FusionKind::kLoop) ||
          instr.opcode() == HloOpcode::kGather ||
          instr.opcode() == HloOpcode::kIota ||
          instr.opcode() == HloOpcode::kPad ||
          (instr.opcode() == HloOpcode::kReduce &&
           !IsReductionFromOrToContiguousDimensions(instr) &&
           !instr.shape().IsTuple()) ||  // TODO(b/129089333): Don't fuse
                                         // variadic reductions.
          instr.opcode() == HloOpcode::kReduceWindow ||
          instr.opcode() == HloOpcode::kReshape ||
          instr.opcode() == HloOpcode::kReverse ||
          instr.opcode() == HloOpcode::kSlice ||
          instr.opcode() == HloOpcode::kConstant ||
          instr.opcode() == HloOpcode::kTranspose);
}

bool IsProducerConsumerFusible(const HloInstruction& producer,
                               const HloInstruction& consumer) {
  if (!IsLoopFusible(producer)) {
    VLOG(5) << "Producer " << producer.name() << " is not loop-fusible";
    return false;
  }

  if (!IsInputFusible(consumer) && !IsLoopFusible(consumer)) {
    VLOG(5) << "Consumer " << consumer.name()
            << " is not input-fusible and not loop-fusible";
    return false;
  }

  // Skip multiple output fusion. It's not yet supported.
  if (producer.IsMultiOutputFusion()) {
    VLOG(5) << "Producer " << producer.name()
            << " is not fusible as it is a multi-output fusion";
    return false;
  }

  if (CreatesNestedLoop(producer, consumer)) {
    VLOG(5) << "Fusing " << producer.name() << " into " << consumer.name()
            << " creates a nested loop";
    return false;
  }

  // Do not fuse into reduce input fusions if the resulting kernel would
  // suffer from poor data locality (due to unfriendly input layouts).
  if (IsInputFusibleReduction(consumer) &&
      !LayoutsAreReduceInputFusionFriendly(producer, consumer)) {
    VLOG(5) << "Layout of " << producer.name()
            << " is not fusion-friendly for consumer reduction "
            << consumer.name();
    return false;
  }

  // Fuse scalar constants into loop fusion nodes. This reduces the number of
  // parameters and makes matching scalar broadcasts easier.
  //
  // Don't fuse other constants: Unfused constants in GPU land can be
  // represented as external constants (i.e. not emitted in LLVM IR / PTX),
  // but fused constants are handled by shared CPU/GPU code and always emitted
  // in the IR/PTX. The external constant representation makes for faster
  // compiles and significantly smaller assembly code.
  if (producer.opcode() == HloOpcode::kConstant &&
      (!ShapeUtil::IsEffectiveScalar(producer.shape()) ||
       consumer.opcode() != HloOpcode::kFusion)) {
    VLOG(5) << "Not fusing constant " << producer.name() << " into "
            << consumer.name();
    return false;
  }

  return true;
}

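// Returns true if `producer` and `consumer` may be combined into a (new)
// multi-output fusion: shapes and layouts must be compatible, and the fusion
// must not create a nested loop.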
bool IsProducerConsumerMultiOutputFusible(const HloInstruction& producer,
                                          const HloInstruction& consumer) {
  // Skip multiple output fusion. It's not yet supported.
  if (producer.IsMultiOutputFusion()) {
    return false;
  }

  if (!IsLoopFusible(producer) ||
      !IsFusibleAsMultiOutputFusionRoot(consumer)) {
    return false;
  }
  if (CreatesNestedLoop(producer, consumer)) {
    return false;
  }
  if (!ShapesCompatibleForMultiOutputFusion(producer, consumer)) {
    return false;
  }
  if (!LayoutsAreReduceInputFusionFriendly(producer, consumer)) {
    return false;
  }
  return true;
}

// Returns shared memory usage for a given instruction in bytes.
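//
// For example, with the model below a row reduction of f32 accounts for
// 32 * 4 = 128 bytes, while a column reduction of f32 accounts for
// 2 * 32 * 33 * 4 = 8448 bytes of shared memory.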
static int64 SharedMemoryUsage(const HloInstruction& instr) {
  // For now, only reductions use shared memory.
  if (instr.opcode() == HloOpcode::kReduce &&
      IsReductionFromOrToContiguousDimensions(instr)) {
    ReductionDimensions reduction_info =
        GetReductionKindAndContiguousComponents(instr);
    int64 primitive_size =
        ShapeUtil::ByteSizeOfPrimitiveType(instr.shape().element_type());
    if (reduction_info.is_row_reduction) {
      // __shared__[32] is used for row reduction.
      return 32 * primitive_size;
    } else {
      // A __shared__[2][32][33] cache is used for column reduction ("2" comes
      // from potential x-tiling).
      return 2 * 32 * 33 * primitive_size;
    }
  } else if (instr.opcode() == HloOpcode::kFusion) {
    int64 sum = 0;
    for (const HloInstruction* hlo :
         instr.fused_instructions_computation()->MakeInstructionPostOrder()) {
      sum += SharedMemoryUsage(*hlo);
    }
    return sum;
  }
  // Other instructions don't count against the shared memory budget for now.
  return 0;
}

// This function limits the maximum number of operands to a fusion, and the
// amount of shared memory which can be consumed by the fusion.
//
// There's a cap on how many parameters we can pass to a CUDA kernel, but
// exactly what that limit is remains hazy, as it depends on (among other
// things) how much GPU constant memory is in use for other purposes.
//
// Moreover, we don't even know at the point that we're running fusion how
// many arguments the CUDA kernel for a fusion node will have: It depends on
// buffer assignment, where we will decide which of the fusion's operands live
// in XLA's big temp buffer versus in other allocations.
//
// As a heuristic, we simply cap the number of fusion operands plus outputs at
// kMaxOperandsAndOutputsPerFusion. This puts an upper bound on the number of
// parameters to the kernel, working around the correctness problem.
//
// This limit is also often good for performance. In a fusion with many
// operands, each GPU thread likely has to do a lot of work, and so possibly
// uses a lot of registers, thus limiting occupancy.
//
// If the fusion is a producer/consumer fusion and instr1 is the consumer and
// instr2 is the producer, set is_consumer_producer_fusion to true to enable
// more fusion.
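//
// For example, fusing a producer with 3 operands into a consumer with 4
// operands, each with a single array output (one subshape each), is bounded
// by 3 + 4 - 1 = 6 operands plus 2 output buffers, i.e. 8 parameters and
// outputs in total.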
bool FusionWouldBeTooLarge(const HloInstruction& instr1,
                           const HloInstruction& instr2,
                           bool is_consumer_producer_fusion) {
  if (SharedMemoryUsage(instr1) + SharedMemoryUsage(instr2) >
      kSharedMemoryBudgetInBytes) {
    VLOG(5) << "Shared memory usage of fusion of " << instr1.ToString()
            << " and " << instr2.ToString() << " would be over the budget of "
            << kSharedMemoryBudgetInBytes << "B";
    return true;
  }

  // Compute the number of outputs of the (possibly multi-output) fusion node
  // we're considering creating.
  //
  // This isn't precise; we may be off by one if
  //  - We're creating a multi-output fusion out of two non-MOFs. Creating a
  //    MOF adds a new buffer, namely, the tuple buffer.
  //  - We're merging two MOFs. In this case, we should count the tuple buffer
  //    only once.
  //  - WLOG there's an edge from `a` to `b` and `b` is the only consumer of
  //    `a`. In this case the result of `a` is not part of the output of the
  //    fusion.
  //
  // But because this is a heuristic and our limit
  // kMaxOperandsAndOutputsPerFusion is a large value (so +/- 1 doesn't make a
  // big difference), we ignore this small inaccuracy in favor of simplicity.
  int64 num_output_buffers = ShapeUtil::SubshapeCount(instr1.shape()) +
                             ShapeUtil::SubshapeCount(instr2.shape());

  // The new fusion will have no more operands and outputs than
  //   producer_operands + consumer_operands - 1 + num_output_buffers
  // (minus one because we may be fusing a producer->consumer edge between `a`
  // and `b`).
  //
  // This fact may be enough to let us avoid having to compute the true total
  // number of operands, which can be expensive.
  if (instr1.operand_count() + instr2.operand_count() - 1 +
          num_output_buffers <=
      kMaxOperandsAndOutputsPerFusion) {
    return false;
  } else {
    VLOG(5) << "Operand count of "
            << "(" << instr1.ToString() << " ) = " << instr1.operand_count()
            << " and ( " << instr2.ToString()
            << " ) = " << instr2.operand_count()
            << " and num_output_buffers = " << num_output_buffers
            << " is bigger than the bound of "
            << kMaxOperandsAndOutputsPerFusion;
  }

  // Compute the precise number of operands to the new fusion.
  absl::flat_hash_set<const HloInstruction*> operands(instr1.operands().begin(),
                                                      instr1.operands().end());
  operands.insert(instr2.operands().begin(), instr2.operands().end());
  // If there's an edge between `instr1` and `instr2`, don't count it: We're
  // fusing that producer -> consumer relationship.
  operands.erase(&instr1);
  operands.erase(&instr2);

  // If the fusion has no more inputs than before, it won't get bigger after
  // fusion, so accept it. Since this is a consumer/producer fusion, the
  // number of consumer outputs doesn't change, so there's no need to check
  // it.
  if (is_consumer_producer_fusion &&
      operands.size() <= instr1.operands().size()) {
    return false;
  }

  // Does the new fusion have more operands and outputs than the max?
  return operands.size() + num_output_buffers > kMaxOperandsAndOutputsPerFusion;
}

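// Returns true if fusing `producer` into `consumer` would create a nested
// loop: the producer contains an instruction that reads elements multiple
// times (and hence codegens a loop), and the producer's output feeds,
// directly or transitively, another such instruction inside the consumer.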
bool CreatesNestedLoop(const HloInstruction& producer,
                       const HloInstruction& consumer) {
  // If the producer has no instruction that codegens a loop, there is nothing
  // to do.
  auto has_loop_codegen = [](const HloInstruction& instr) {
    if (instr.opcode() != HloOpcode::kFusion) {
      return IfFusedReadsElementsMultipleTimes(instr);
    }
    for (const HloInstruction* fused : instr.fused_instructions()) {
      if (IfFusedReadsElementsMultipleTimes(*fused)) {
        return true;
      }
    }
    return false;
  };
  if (!has_loop_codegen(producer)) {
    return false;
  }

  // If the consumer is a non-fusion instruction, we only have to check
  // whether it generates a loop itself.
  if (consumer.opcode() != HloOpcode::kFusion) {
    return IfFusedReadsElementsMultipleTimes(consumer);
  }

  // If the consumer is a fusion, we have to check whether the output of the
  // producer is used directly or indirectly as an input to an HLO instruction
  // that generates a loop, i.e. whether there is a path in the graph from the
  // parameter corresponding to the producer to an HLO instruction generating
  // a loop in the consumer.
  for (const HloInstruction* operand : consumer.operands()) {
    if (operand != &producer) {
      continue;
    }

    const HloInstruction* root =
        consumer.fused_instructions_computation()->parameter_instruction(
            consumer.operand_index(operand));

    std::stack<const HloInstruction*> dfs;
    dfs.push(root);
    absl::flat_hash_set<const HloInstruction*> visited;
    while (!dfs.empty()) {
      const HloInstruction* cur = dfs.top();
      dfs.pop();

      if (visited.contains(cur)) {
        continue;
      }
      visited.insert(cur);

      if (IfFusedReadsElementsMultipleTimes(*cur)) {
        return true;
      }
      for (const HloInstruction* user : cur->users()) {
        if (visited.contains(user)) {
          continue;
        }
        dfs.push(user);
      }
    }
  }
  return false;
}

bool IsFusibleAsMultiOutputFusionRoot(const HloInstruction& instr) {
  // We can fuse reduces and loop fusions. Elementwise instructions can be
  // fused with any other instruction.
  // Note that scatter cannot be the root of a multi-output fusion because
  // its emitter doesn't support it.

  return instr.IsFusible() &&
         (IsInputFusibleReduction(instr) ||
          instr.IsLoopFusion() ||  // TODO(b/130013493): Use IsLoopFusible.
          instr.IsElementwise());
}

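// Input-fusible consumers (reductions and scatter) require a kInput fusion;
// everything else defaults to a kLoop fusion.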
HloInstruction::FusionKind ChooseFusionKind(const HloInstruction& /*producer*/,
                                            const HloInstruction& consumer) {
  return IsInputFusible(consumer) ? HloInstruction::FusionKind::kInput
                                  : HloInstruction::FusionKind::kLoop;
}

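// Returns true if every user of `instr`, looking through get-tuple-element
// ops, is either `consumer` itself or the root of the computation.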
bool IsConsumerTheOnlyNonRootUser(const HloInstruction& instr,
                                  const HloInstruction& consumer) {
  return absl::c_all_of(instr.users(), [&](const HloInstruction* user) {
    if (user->opcode() == HloOpcode::kGetTupleElement) {
      // Skip GTE.
      return IsConsumerTheOnlyNonRootUser(*user, consumer);
    }
    if (user == &consumer) {
      // `user` is `consumer`.
      return true;
    }
    if (user == user->parent()->root_instruction()) {
      // Consumed by ROOT.
      return true;
    }
    return false;
  });
}

}  // namespace gpu
}  // namespace xla