• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
17 
18 #include "absl/container/flat_hash_set.h"
19 #include "tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h"
20 #include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
21 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
22 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
23 #include "tensorflow/compiler/xla/service/hlo_query.h"
24 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
25 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 
29 namespace xla {
30 namespace gpu {
31 
32 namespace {
ElementIsF32OrF16(const Shape & shape)33 bool ElementIsF32OrF16(const Shape& shape) {
34   PrimitiveType type = shape.element_type();
35   return type == F32 || type == F16;
36 }
37 }  // namespace
38 
IsExpensive(const HloInstruction & instruction)39 /*static*/ bool GpuInstructionFusion::IsExpensive(
40     const HloInstruction& instruction) {
41   // We say that some floating-point math ops are cheap on the GPU. Unlike other
42   // intrinsics that can be expanded into many instructions, Div and Rsqrt are
43   // lowered into single hardware instructions.
44   switch (instruction.opcode()) {
45     case HloOpcode::kDivide:
46     case HloOpcode::kRsqrt:
47     case HloOpcode::kExp:
48       if (ElementIsF32OrF16(instruction.shape())) {
49         return false;
50       }
51       break;
52     default:
53       break;
54   }
55   return InstructionFusion::IsExpensive(instruction);
56 }
57 
ShouldFuseInexpensiveChecks(HloInstruction * consumer,int64_t operand_index)58 FusionDecision GpuInstructionFusion::ShouldFuseInexpensiveChecks(
59     HloInstruction* consumer, int64_t operand_index) {
60   HloInstruction* producer = consumer->mutable_operand(operand_index);
61 
62   // Output fusions are not currently supported on GPUs.
63   if (producer->opcode() == HloOpcode::kFusion) {
64     return "the producer is a fusion";
65   }
66   // Cost condition: not fuse (simple, expensive producers) and (consumers who
67   // reuse operand elements).
68   if (producer->opcode() != HloOpcode::kFusion && is_expensive(*producer) &&
69       ReusesOperandElements(consumer, operand_index)) {
70     return "the producer is expensive, and the consumer reuses inputs";
71   }
72 
73   if (NoFusionPossible fusible =
74           !IsProducerConsumerFusible(*producer, *consumer)) {
75     return !fusible;
76   }
77   if (NoFusionPossible fusible =
78           !InstructionFusion::ShouldFuse(consumer, operand_index)) {
79     return !fusible;
80   }
81   return {};
82 }
83 
ShouldFuse(HloInstruction * consumer,int64_t operand_index)84 FusionDecision GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
85                                                 int64_t operand_index) {
86   if (NoFusionPossible fusible =
87           !ShouldFuseInexpensiveChecks(consumer, operand_index)) {
88     return !fusible;
89   }
90 
91   auto producer = consumer->operand(operand_index);
92 
93   // The following checks are potentially expensive.
94   if (NoFusionPossible too_large =
95           !FusionFitsInBudget(*consumer, *producer,
96                               /*is_consumer_producer_fusion=*/true)) {
97     return !too_large;
98   }
99 
100   if (consumer->opcode() != HloOpcode::kFusion) {
101     return {};
102   }
103 
104   // Also check that our emitter can handle the fusion node. We currently can
105   // have exponential time/memory requirements for emitting certain fusion
106   // kernels, in which case we don't want to fuse.
107   // TODO(b/119692968): Remove this once we have fixed our fusion emitter.
108   if (fusion_node_evaluations_.find(consumer) ==
109       fusion_node_evaluations_.end()) {
110     // We have no cached results for this fusion node yet. This can happen when
111     // we run the InstructionFusion pass more than once. We can only cache the
112     // results within one run.
113     fusion_node_evaluations_.emplace(consumer,
114                                      FusionNodeIndexingEvaluation(consumer));
115   }
116   if (fusion_node_evaluations_.at(consumer).CodeDuplicationTooHigh(producer)) {
117     return "the fusion would result in an overly large code duplication";
118   }
119   return {};
120 }
121 
// Picks the fusion kind (kLoop, kInput, ...) for fusing `producer` into
// `consumer` by delegating to the shared GPU heuristic in gpu_fusible.
HloInstruction::FusionKind GpuInstructionFusion::ChooseKind(
    const HloInstruction* producer, const HloInstruction* consumer) {
  return ChooseFusionKind(*producer, *consumer);
}
126 
FuseInstruction(HloInstruction * fusion_instruction,HloInstruction * producer)127 HloInstruction* GpuInstructionFusion::FuseInstruction(
128     HloInstruction* fusion_instruction, HloInstruction* producer) {
129   auto evaluation = fusion_node_evaluations_.find(fusion_instruction);
130   if (evaluation == fusion_node_evaluations_.end()) {
131     evaluation = fusion_node_evaluations_
132                      .emplace(fusion_instruction,
133                               FusionNodeIndexingEvaluation(fusion_instruction))
134                      .first;
135   }
136   auto indexing_users = evaluation->second.RemoveFusionOperand(producer);
137   HloInstruction* new_producer =
138       InstructionFusion::FuseInstruction(fusion_instruction, producer);
139   evaluation->second.UpdateEvaluationCache(new_producer, indexing_users);
140   return new_producer;
141 }
142 
143 }  // namespace gpu
144 }  // namespace xla
145