/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"

#include <algorithm>
#include <functional>

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/fusion_node_indexing_evaluation.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

using llvm_ir::IrArray;
Status FusedIrEmitter::DefaultAction(const HloInstruction* hlo) {
  indexed_generators_[hlo] =
      [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
    if (llvm::Value* generated_value = FindOrDefault(
            generated_value_cache_[hlo], index.multidim(), nullptr)) {
      llvm::BasicBlock* generated_value_bb = nullptr;
      if (auto* generated_instruction =
              llvm::dyn_cast<llvm::Instruction>(generated_value)) {
        generated_value_bb = generated_instruction->getParent();
      }
      // Ideally, we should be able to reuse the cached generated value if it
      // dominates the current insertion block. However, the check for
      // dominance can be expensive and unreliable when the function is being
      // constructed.
      //
      // It is also worth experimenting with not doing caching at all. LLVM's
      // CSE or GVN should be able to easily merge common subexpressions that
      // would be regenerated without caching. But this might increase the JIT
      // compilation time.
      if (generated_value_bb == nullptr ||
          generated_value_bb == b_->GetInsertBlock()) {
        VLOG(3) << "The cached generated value is reused.";
        return generated_value;
      }
      VLOG(3) << "The cached generated value can't be reused, because it is in "
                 "a different BB ("
              << generated_value_bb->getName().str()
              << ") from the current insertion block ("
              << b_->GetInsertBlock()->getName().str() << ").";
    }

    TF_ASSIGN_OR_RETURN(llvm::Value* const generated_value,
                        elemental_emitter_->MakeElementGenerator(
                            hlo, indexed_generators_)(index));
    generated_value_cache_[hlo][index.multidim()] = generated_value;
    return generated_value;
  };
  return Status::OK();
}

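// Registers a generator that materializes the constant's literal as a
// private, unnamed global in the global memory address space and reads
// elements from it through an IrArray view of the constant's shape.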
Status FusedIrEmitter::HandleConstant(const HloInstruction* constant) {
  unsigned global_address_space =
      llvm_ir::GetGlobalMemoryAddressSpace(*module_);
  indexed_generators_[constant] = [=](const IrArray::Index& index) {
    const Literal& literal = constant->literal();
    llvm::Constant* initializer =
        llvm_ir::ConvertLiteralToIrConstant(literal, module_);
    llvm::GlobalVariable* global = new llvm::GlobalVariable(
        *b_->GetInsertBlock()->getModule(), initializer->getType(),
        /*isConstant=*/true,
        /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
        /*Initializer=*/initializer,
        /*Name=*/"", /*InsertBefore=*/nullptr,
        /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
        /*AddressSpace=*/global_address_space,
        /*isExternallyInitialized=*/false);

    global->setUnnamedAddr(llvm::GlobalVariable::UnnamedAddr::Global);
    llvm::Constant* shape_constant =
        llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast(
            global,
            llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo());
    return IrArray(shape_constant, constant->shape())
        .EmitReadArrayElement(index, b_, constant->name());
  };

  return Status::OK();
}

Status FusedIrEmitter::HandleGetTupleElement(
    const HloInstruction* get_tuple_element) {
  return InternalError("Tuple parameters are not supported for fusion");
}

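// Parameter generators must be bound by the caller before emission; an
// unbound parameter is reported as an error here.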
Status FusedIrEmitter::HandleParameter(const HloInstruction* parameter) {
  if (indexed_generators_.find(parameter) == indexed_generators_.end()) {
    return InvalidArgument("Unbound parameter: %s", parameter->ToString());
  }
  return Status::OK();
}

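// Registers a generator that assembles the tuple element at the given index
// as an LLVM struct built from the generators of each tuple operand.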
Status FusedIrEmitter::HandleTuple(const HloInstruction* tuple) {
  absl::Span<HloInstruction* const> operands(tuple->operands());
  std::vector<llvm::Type*> operand_elemental_ir_types;
  for (HloInstruction* operand : operands) {
    operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType(
        operand->shape().element_type(), module_));
  }
  indexed_generators_[tuple] =
      [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
    llvm::Value* ret = llvm::UndefValue::get(
        llvm::StructType::get(b_->getContext(), operand_elemental_ir_types));
    for (size_t i = 0; i < ShapeUtil::TupleElementCount(tuple->shape()); ++i) {
      TF_ASSIGN_OR_RETURN(llvm::Value * val_i,
                          indexed_generators_[operands[i]](index));
      ret = b_->CreateInsertValue(ret, val_i, i);
    }
    return ret;
  };
  return Status::OK();
}

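// Heuristic used to decide whether fusing 'producer' into 'consumer' would
// duplicate too much code. Fusion into a non-fusion consumer is never
// flagged; otherwise FusionNodeIndexingEvaluation estimates the duplication.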
bool FusedIrEmitter::IsFusedIrEmitterInefficient(
    const HloInstruction* consumer, const HloInstruction* producer) {
  if (consumer->opcode() != HloOpcode::kFusion) {
    return false;
  }
  FusionNodeIndexingEvaluation eval_consumer(consumer);
  if (producer->opcode() != HloOpcode::kFusion) {
    return eval_consumer.CodeDuplicationTooHigh(producer);
  }
  // If 'producer' is a fusion node as well, also evaluate it. Pass the
  // evaluated duplication of the fusion node if it is merged into consumer.
  FusionNodeIndexingEvaluation eval_producer(
      producer, eval_consumer.EvaluateEmittedInstructions(producer));
  return eval_producer.MaxCodeDuplicationTooHigh();
}

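// Returns the indexed generator for 'instruction', lazily creating generators
// for it and for any not-yet-visited transitive operands using an explicit
// work stack.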
StatusOr<FusedIrEmitter::IndexedGenerator> FusedIrEmitter::GetGenerator(
    const HloInstruction* instruction) {
  std::vector<const HloInstruction*> stack;
  stack.push_back(instruction);
  while (!stack.empty()) {
    const HloInstruction* instr = stack.back();
    stack.pop_back();
    if (indexed_generators_.count(instr)) {
      continue;
    }
    for (const HloInstruction* operand : instr->operands()) {
      stack.push_back(operand);
    }
    switch (instr->opcode()) {
      case HloOpcode::kConstant:
        TF_RETURN_IF_ERROR(HandleConstant(instr));
        break;
      case HloOpcode::kGetTupleElement:
        TF_RETURN_IF_ERROR(HandleGetTupleElement(instr));
        break;
      case HloOpcode::kParameter:
        TF_RETURN_IF_ERROR(HandleParameter(instr));
        break;
      case HloOpcode::kTuple:
        TF_RETURN_IF_ERROR(HandleTuple(instr));
        break;
      default:
        TF_RETURN_IF_ERROR(DefaultAction(instr));
        break;
    }
    CHECK(indexed_generators_.count(instr));
  }
  return indexed_generators_.at(instruction);
}

}  // namespace xla