/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"

#include <algorithm>
#include <functional>

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

using llvm_ir::IrArray;
Status FusedIrEmitter::DefaultAction(const HloInstruction* hlo) {
  indexed_generators_[hlo] =
      [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
    if (llvm::Value* generated_value = FindOrDefault(
            generated_value_cache_[hlo], index.multidim(), nullptr)) {
      llvm::BasicBlock* generated_value_bb = nullptr;
      if (auto* generated_instruction =
              llvm::dyn_cast<llvm::Instruction>(generated_value)) {
        generated_value_bb = generated_instruction->getParent();
      }
      // Ideally, we should be able to reuse the cached generated value if it
      // dominates the current insertion block. However, the check for
      // dominance can be expensive and unreliable while the function is still
      // being constructed.
      //
      // It's also worth experimenting with not caching at all: LLVM's CSE or
      // GVN should be able to merge the common subexpressions that would be
      // regenerated without caching, but this might increase the JIT
      // compilation time.
      if (generated_value_bb == nullptr ||
          generated_value_bb == b_->GetInsertBlock()) {
        VLOG(3) << "The cached generated value is reused.";
        return generated_value;
      }
      VLOG(3) << "The cached generated value can't be reused, because it is in "
                 "a different BB ("
              << generated_value_bb->getName().str()
              << ") from the current insertion block ("
              << b_->GetInsertBlock()->getName().str() << ").";
    }

    // Cache miss: emit the element via the elemental emitter and memoize the
    // result for this index.
    TF_ASSIGN_OR_RETURN(llvm::Value* const generated_value,
                        elemental_emitter_->MakeElementGenerator(
                            hlo, indexed_generators_)(index));
    generated_value_cache_[hlo][index.multidim()] = generated_value;
    return generated_value;
  };
  return Status::OK();
}
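
// Example (illustrative sketch, not part of the emitter): a caller obtains the
// generator installed above via GetGenerator() and invokes it once per output
// index. The names `fused_emitter`, `root`, and `index` are hypothetical
// stand-ins for whatever the calling emitter has in scope.
//
//   TF_ASSIGN_OR_RETURN(FusedIrEmitter::IndexedGenerator generator,
//                       fused_emitter.GetGenerator(root));
//   TF_ASSIGN_OR_RETURN(llvm::Value * element, generator(index));
//   // `element` holds the scalar value of `root` at `index`; calling the
//   // generator again with the same index in the same basic block hits the
//   // cache above instead of re-emitting the computation.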
81
Status FusedIrEmitter::HandleConstant(const HloInstruction* constant) {
  indexed_generators_[constant] = [=](const IrArray::Index& index) {
    const Literal& literal = constant->literal();
    // Materialize the literal as a private constant global in the module,
    // then read the requested element out of it.
    llvm::Constant* initializer =
        llvm_ir::ConvertLiteralToIrConstant(literal, module_);
    llvm::GlobalVariable* global = new llvm::GlobalVariable(
        *b_->GetInsertBlock()->getModule(), initializer->getType(),
        /*isConstant=*/true,
        /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
        /*Initializer=*/initializer,
        /*Name=*/"", /*InsertBefore=*/nullptr,
        /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
        /*AddressSpace=*/0,
        /*isExternallyInitialized=*/false);

    global->setUnnamedAddr(llvm::GlobalVariable::UnnamedAddr::Global);
    llvm::Constant* shape_constant =
        llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast(
            global,
            llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo());
    return IrArray(shape_constant, constant->shape())
        .EmitReadArrayElement(index, b_, constant->name());
  };

  return Status::OK();
}
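
// Note on the flags used above: private linkage plus unnamed_addr tells LLVM
// that the address of the global is not significant, so identical constant
// literals can later be merged (e.g. by ConstantMerge). A rough, illustrative
// sketch of the resulting global for a small f32[4] literal (actual names and
// formatting depend on the literal and target):
//
//   @0 = private unnamed_addr constant [4 x float]
//        [float 1.0, float 2.0, float 3.0, float 4.0]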
108
Status FusedIrEmitter::HandleGetTupleElement(
    const HloInstruction* get_tuple_element) {
  return InternalError("Tuple parameters are not supported for fusion");
}
113
Status FusedIrEmitter::HandleParameter(const HloInstruction* parameter) {
  if (indexed_generators_.find(parameter) == indexed_generators_.end()) {
    return InvalidArgument("Unbound parameter: %s", parameter->ToString());
  }
  return Status::OK();
}
120
Status FusedIrEmitter::HandleTuple(const HloInstruction* tuple) {
  absl::Span<HloInstruction* const> operands(tuple->operands());
  std::vector<llvm::Type*> operand_elemental_ir_types;
  for (HloInstruction* operand : operands) {
    operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType(
        operand->shape().element_type(), module_));
  }
  // The generator produces a struct whose fields are the per-operand element
  // values at `index`, built up with a chain of insertvalue instructions.
  indexed_generators_[tuple] =
      [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
    llvm::Value* ret = llvm::UndefValue::get(
        llvm::StructType::get(b_->getContext(), operand_elemental_ir_types));
    for (size_t i = 0; i < ShapeUtil::TupleElementCount(tuple->shape()); ++i) {
      TF_ASSIGN_OR_RETURN(llvm::Value * val_i,
                          indexed_generators_[operands[i]](index));
      ret = b_->CreateInsertValue(ret, val_i, i);
    }
    return ret;
  };
  return Status::OK();
}
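
// Illustrative sketch (value names invented) of the IR the generator above
// builds for a two-element tuple of f32 and s32 operands:
//
//   %t0 = insertvalue { float, i32 } undef, float %val0, 0
//   %t1 = insertvalue { float, i32 } %t0, i32 %val1, 1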
141
bool FusedIrEmitter::IsFusedIrEmitterInefficient(
    const HloInstruction* consumer, const HloInstruction* producer) {
  if (consumer->opcode() != HloOpcode::kFusion) {
    return false;
  }
  FusionNodeIndexingEvaluation eval_consumer(consumer);
  if (producer->opcode() != HloOpcode::kFusion) {
    return eval_consumer.CodeDuplicationTooHigh(producer);
  }
  // If 'producer' is a fusion node as well, also evaluate it. Pass along the
  // evaluated duplication of 'producer' under the assumption that it is
  // merged into 'consumer'.
  FusionNodeIndexingEvaluation eval_producer(
      producer, eval_consumer.EvaluateEmittedInstructions(producer));
  return eval_producer.MaxCodeDuplicationTooHigh();
}
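
// Example (hypothetical caller, for illustration only): a fusion pass could
// consult this predicate before merging two instructions. `consumer` and
// `producer` here stand for whatever fusion candidates the pass is
// considering.
//
//   if (FusedIrEmitter::IsFusedIrEmitterInefficient(consumer, producer)) {
//     // Skip this fusion: emitting it would duplicate too much code.
//     continue;
//   }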
157
StatusOr<FusedIrEmitter::IndexedGenerator> FusedIrEmitter::GetGenerator(
    const HloInstruction* instruction) {
  // Iterative DFS: make sure a generator exists for 'instruction' and for all
  // of its transitive operands. Parameter generators must already have been
  // bound by the caller; see HandleParameter.
  std::vector<const HloInstruction*> stack;
  stack.push_back(instruction);
  while (!stack.empty()) {
    const HloInstruction* instr = stack.back();
    stack.pop_back();
    if (indexed_generators_.count(instr)) {
      continue;
    }
    for (const HloInstruction* operand : instr->operands()) {
      stack.push_back(operand);
    }
    switch (instr->opcode()) {
      case HloOpcode::kConstant:
        TF_RETURN_IF_ERROR(HandleConstant(instr));
        break;
      case HloOpcode::kGetTupleElement:
        TF_RETURN_IF_ERROR(HandleGetTupleElement(instr));
        break;
      case HloOpcode::kParameter:
        TF_RETURN_IF_ERROR(HandleParameter(instr));
        break;
      case HloOpcode::kTuple:
        TF_RETURN_IF_ERROR(HandleTuple(instr));
        break;
      default:
        TF_RETURN_IF_ERROR(DefaultAction(instr));
        break;
    }
    CHECK(indexed_generators_.count(instr));
  }
  return indexed_generators_.at(instruction);
}
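
// Example (hypothetical caller sketch): a typical fusion emitter binds a
// generator for each fused parameter first, then asks for the root's
// generator. `BindGenerator`, `fusion`, `operand_arrays`, and `b` are
// assumptions about the surrounding code, shown only to illustrate the
// contract that HandleParameter checks above.
//
//   FusedIrEmitter fused_emitter(&elemental_emitter);
//   for (int i = 0; i < fusion->operand_count(); ++i) {
//     fused_emitter.BindGenerator(
//         fusion->fused_parameter(i), [&, i](const IrArray::Index& index) {
//           return operand_arrays[i].EmitReadArrayElement(index, &b);
//         });
//   }
//   TF_ASSIGN_OR_RETURN(
//       auto root_generator,
//       fused_emitter.GetGenerator(fusion->fused_expression_root()));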
192
}  // namespace xla