/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"

#include <algorithm>
#include <functional>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

using llvm_ir::IrArray;

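// Installs an indexed generator for 'hlo' that delegates to the elemental IR
// emitter and caches the generated value per index, so revisiting the same
// element from the same basic block does not re-emit its computation.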
Status FusedIrEmitter::DefaultAction(HloInstruction* hlo) {
  indexed_generators_[hlo] =
      [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
    if (generated_value_cache_[hlo].contains(index.multidim())) {
      llvm::Value* generated_value =
          generated_value_cache_[hlo][index.multidim()];
      llvm::BasicBlock* generated_value_bb = nullptr;
      if (auto* generated_instruction =
              llvm::dyn_cast<llvm::Instruction>(generated_value)) {
        generated_value_bb = generated_instruction->getParent();
      }
      // Ideally, we should be able to reuse the cached generated value if it
      // dominates the current insertion block. However, the check for
      // dominance can be expensive and unreliable when the function is being
      // constructed.
      //
      // It's also worth experimenting with what happens if we don't do
      // caching at all. LLVM's CSE or GVN should be able to easily merge
      // common subexpressions that would be regenerated without caching,
      // but this might increase the JIT compilation time.
      if (generated_value_bb == nullptr ||
          generated_value_bb == b_->GetInsertBlock()) {
        VLOG(3) << "The cached generated value is reused.";
        return generated_value;
      }
      VLOG(3) << "The cached generated value can't be reused, because it is in "
                 "a different BB ("
              << generated_value_bb->getName().str()
              << ") from the current insertion block ("
              << b_->GetInsertBlock()->getName().str() << ").";
    }

    TF_ASSIGN_OR_RETURN(generated_value_cache_[hlo][index.multidim()],
                        elemental_emitter_->MakeElementGenerator(
                            hlo, indexed_generators_)(index));
    return generated_value_cache_[hlo][index.multidim()];
  };
  return Status::OK();
}

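// Installs an indexed generator that materializes the fused constant as a
// private global in the module and reads the element at 'index' from it.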
Status FusedIrEmitter::HandleConstant(HloInstruction* constant) {
  indexed_generators_[constant] = [=](const IrArray::Index& index) {
    const Literal& literal = constant->literal();
    llvm::Constant* initializer =
        llvm_ir::ConvertLiteralToIrConstant(literal, module_);
    llvm::GlobalVariable* global = new llvm::GlobalVariable(
        *b_->GetInsertBlock()->getModule(), initializer->getType(),
        /*isConstant=*/true,
        /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
        /*Initializer=*/initializer,
        /*Name=*/"");
    global->setUnnamedAddr(llvm::GlobalVariable::UnnamedAddr::Global);
    llvm::Constant* shape_constant = llvm::ConstantExpr::getBitCast(
        global,
        llvm_ir::ShapeToIrType(literal.shape(), module_)->getPointerTo());
    return IrArray(shape_constant, constant->shape())
        .EmitReadArrayElement(index, b_);
  };

  return Status::OK();
}

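// Installs a generator for a GetTupleElement nested in the fusion: a
// non-indexed pointer generator for tuple-shaped results, and an indexed
// generator that reads the element at 'index' otherwise.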
Status FusedIrEmitter::HandleGetTupleElement(
    HloInstruction* get_tuple_element) {
  auto emit_tuple_element_ptr = [=]() -> StatusOr<llvm::Value*> {
    const HloInstruction* tuple_operand = get_tuple_element->operand(0);
    llvm::Value* tuple_ptr;
    if (tuple_operand->opcode() == HloOpcode::kGetTupleElement) {
      TF_ASSIGN_OR_RETURN(tuple_ptr, non_indexed_generators_[tuple_operand]());
    } else {
      if (tuple_operand->opcode() != HloOpcode::kParameter) {
        return Unimplemented(
            "GetTupleElement fusion currently only supports parameter or "
            "nested GetTupleElement as tuple operand, found an exception: %s",
            tuple_operand->name());
      }
      tuple_ptr =
          GetBasePointerForFusedParameter(tuple_operand->parameter_number());
    }

    // Lookup tuple element pointer.
    return llvm_ir::EmitGetTupleElement(get_tuple_element->shape(),
                                        get_tuple_element->tuple_index(),
                                        /*alignment=*/1, tuple_ptr, b_);
  };

  if (!get_tuple_element->shape().IsTuple()) {
    indexed_generators_[get_tuple_element] =
        [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
      // TODO(b/34080002) Add aliasing information to tuple element IrArray.
      TF_ASSIGN_OR_RETURN(llvm::Value * tuple_element_ptr,
                          emit_tuple_element_ptr());
      return IrArray(tuple_element_ptr, get_tuple_element->shape())
          .EmitReadArrayElement(index, b_);
    };
  } else {
    non_indexed_generators_[get_tuple_element] = emit_tuple_element_ptr;
  }
  return Status::OK();
}

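// Installs an indexed generator for a fusion parameter. If tiling information
// is present and the parameter has a tile buffer, the element is loaded from
// that buffer; otherwise it is read from the parameter's IrArray.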
Status FusedIrEmitter::HandleParameter(HloInstruction* parameter) {
  indexed_generators_[parameter] =
      [=](const IrArray::Index& index) -> llvm::Value* {
    if (tiled_parameter_info_) {
      if (llvm::Value* param_tile_buffer =
              tiled_parameter_info_->GetBufferForParameter(
                  parameter->parameter_number())) {
        // TODO(jlebar): Add AA metadata to this load. Tile buffers are global
        // variables, so LLVM's points-to analysis doesn't help us much. And
        // we want the AA info to be present before address spaces are
        // inferred (which is pretty late in the pipeline), so even if we had
        // address-space-based AA in LLVM, it wouldn't help us much here.
        return b_->CreateLoad(
            b_->CreateGEP(param_tile_buffer,
                          {index.GetConstantWithIndexType(0),
                           tiled_parameter_info_->x(),
                           tiled_parameter_info_->y()}),
            "tiled_buffer");
      }
    }
    return GetIrArrayForFusedParameter(parameter->parameter_number())
        .EmitReadArrayElement(index, b_);
  };
  return Status::OK();
}

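// Installs an indexed generator that packs the operands' values at 'index'
// into an LLVM struct, one field per tuple element.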
Status FusedIrEmitter::HandleTuple(HloInstruction* tuple) {
  absl::Span<HloInstruction* const> operands(tuple->operands());
  std::vector<llvm::Type*> operand_elemental_ir_types;
  for (HloInstruction* operand : operands) {
    operand_elemental_ir_types.push_back(llvm_ir::PrimitiveTypeToIrType(
        operand->shape().element_type(), module_));
  }
  indexed_generators_[tuple] =
      [=](const IrArray::Index& index) -> StatusOr<llvm::Value*> {
    llvm::Value* ret = llvm::UndefValue::get(
        llvm::StructType::get(b_->getContext(), operand_elemental_ir_types));
    for (size_t i = 0; i < ShapeUtil::TupleElementCount(tuple->shape()); ++i) {
      TF_ASSIGN_OR_RETURN(llvm::Value * val_i,
                          indexed_generators_[operands[i]](index));
      ret = b_->CreateInsertValue(ret, val_i, i);
    }
    return ret;
  };
  return Status::OK();
}

Status FusedIrEmitter::FinishVisit(HloInstruction* root) {
  fused_root_ = root;
  return Status::OK();
}

FusedIrEmitter::IndexedGenerator FusedIrEmitter::GetRootGenerator() const {
  CHECK_NE(nullptr, fused_root_)
      << "GetRootGenerator should be called after Accept.";
  return indexed_generators_.at(fused_root_);
}

FusedIrEmitter::IndexedGenerator FusedIrEmitter::GetGenerator(
    const HloInstruction* instruction) const {
  return indexed_generators_.at(instruction);
}

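// Heuristic that estimates whether fusing 'producer' into 'consumer' would
// duplicate too much code: returns true if the estimated total number of
// emitted ops exceeds 8x the number of instructions whose index usage was
// counted.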
bool FusedIrEmitter::IsFusedIrEmitterInefficient(
    const HloInstruction* consumer, const HloInstruction* producer) {
  if (consumer->opcode() != HloOpcode::kFusion) {
    return false;
  }
  // For each instruction in the fusion node, collect the (indirect) users
  // from which newly created index values are passed to it. Roughly speaking,
  // an index value is reused if the shapes are equal when ignoring the
  // element type (it could also be reused if the shape change is a bitcast,
  // but we don't consider that here). By ignoring these potential reuses, our
  // estimate of whether the fusion emitter is inefficient is a bit more
  // conservative than necessary.
  absl::flat_hash_map<const HloInstruction*,
                      absl::flat_hash_set<const HloInstruction*>>
      indexing_users;
  // Stores the number of different index accesses for each instruction in
  // the fusion node. The fusion emitter caches accesses with the same index,
  // so this value indicates how many times a specific instruction will be
  // emitted.
  absl::flat_hash_map<const HloInstruction*, int64> index_usage_count;
  index_usage_count[consumer] = 1;

  auto evaluate_fusion_computation = [&indexing_users, &index_usage_count](
                                         const HloInstruction* fusion) {
    auto postorder =
        fusion->fused_instructions_computation()->MakeInstructionPostOrder();
    std::reverse(postorder.begin(), postorder.end());
    for (const auto* instruction : postorder) {
      if (instruction->opcode() == HloOpcode::kParameter) {
        continue;
      }
      int64& total = index_usage_count[instruction];
      if (indexing_users[instruction].empty()) {
        total = index_usage_count[fusion];
      } else {
        total = 0;
        for (const auto* user : indexing_users[instruction]) {
          int64 weight = 1;
          // Concatenate is special: the index differs for each operand, so
          // in the worst case we have to deal with as many index values as
          // the number of operands of Concatenate. By considering the worst
          // case, we are more conservative than necessary regarding
          // refusing to fuse.
          if (user->opcode() == HloOpcode::kConcatenate) {
            weight = user->operand_count();
          }
          total += index_usage_count[user] * weight;
        }
      }
      for (const auto* operand : instruction->operands()) {
        // For simplicity we assume that all shape and layout changing
        // operations invalidate index reuse.
        if (Shape::Equal().IgnoreElementType()(operand->shape(),
                                               instruction->shape())) {
          // If the index is reused, it means the operand gets index values
          // from the same set of (indirect) users as 'instruction' itself.
          indexing_users[operand].insert(indexing_users[instruction].begin(),
                                         indexing_users[instruction].end());
        } else {
          // If the index is not reused, it means 'instruction' computes a
          // new index derived from the index it gets.
          indexing_users[operand].insert(instruction);
        }
      }
    }
  };
  evaluate_fusion_computation(consumer);

  // Also account for the 'producer' if it would be fused. Find the operand it
  // corresponds to.
  for (int64 operand_num = 0; operand_num < consumer->operand_count();
       ++operand_num) {
    if (consumer->operand(operand_num) == producer) {
      auto instruction = consumer->fused_parameter(operand_num);
      int64& total = index_usage_count[producer];
      total = 0;
      for (const auto* user : indexing_users[instruction]) {
        total += index_usage_count[user];
      }
      break;
    }
  }

  // If 'producer' is a fusion node as well, also evaluate it.
  if (producer->opcode() == HloOpcode::kFusion) {
    evaluate_fusion_computation(producer);
  }

  // Sum up the total number of emitted ops.
  int64 total = 0;
  for (const auto& entry : index_usage_count) {
    total += entry.second;
  }

  // Report the emitter as inefficient if the code duplication exceeds a
  // factor of 8 (where 8 is an arbitrary constant that seems to work).
  return total > 8 * index_usage_count.size();
}

}  // namespace xla