/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace gpu {

using absl::StrAppend;
using absl::StrCat;

void HloToIrBindings::EmitBasePointersForHlos(
    absl::Span<const HloInstruction* const> io_hlos,
    absl::Span<const HloInstruction* const> non_io_hlos) {
  CHECK(is_nested_);

  // I/O HLOs are bound to the arguments of the current IR function,
  // *excluding* the output argument, which is added to non-I/O HLOs.
  // I.e.,
  //
  //   void IrFunction(io_0, io_1, ..., io_{m-1}, output_arg);
  llvm::Function* function = b_->GetInsertBlock()->getParent();
  CHECK_EQ(io_hlos.size() + 1, function->arg_size());

  // An HLO can have duplicated operands. This data structure remembers which
  // operand HLOs are already bound to avoid rebinding the same HLO.
  absl::flat_hash_set<const HloInstruction*> already_bound_for_this_function;
  auto arg_iter = function->arg_begin();
  for (const HloInstruction* io_hlo : io_hlos) {
    CHECK(io_hlo == io_hlo->parent()->root_instruction() ||
          !absl::c_count(non_io_hlos, io_hlo))
        << "IO HLOs and non-IO HLOs should be disjoint";
    if (!already_bound_for_this_function.contains(io_hlo)) {
      BindHloToIrValue(*io_hlo, &*arg_iter);
      already_bound_for_this_function.insert(io_hlo);
    }
    ++arg_iter;
  }

  // Name and skip the output parameter.
  arg_iter->setName("output_arg");
  ++arg_iter;

  for (const HloInstruction* non_io_hlo : non_io_hlos) {
    if (already_bound_for_this_function.contains(non_io_hlo)) {
      continue;
    }
    already_bound_for_this_function.insert(non_io_hlo);

    if (non_io_hlo->opcode() == HloOpcode::kGetTupleElement) {
      continue;
    }

    ShapeUtil::ForEachSubshape(
        non_io_hlo->shape(),
        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          if (non_io_hlo->opcode() == HloOpcode::kConstant) {
            llvm::Value* global_for_constant = module_->getGlobalVariable(
                llvm_ir::ConstantHloToGlobalName(*non_io_hlo));
            CHECK(global_for_constant)
                << llvm_ir::ConstantHloToGlobalName(*non_io_hlo);
            BindHloToIrValue(*non_io_hlo, global_for_constant);
          } else {
            llvm::Type* pointee_type =
                llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_);
            BindHloToIrValue(*non_io_hlo,
                             llvm_ir::EmitAllocaAtFunctionEntry(
                                 pointee_type, /*name=*/"", b_),
                             index);
          }
        });
  }
}

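// Emits IR that computes the address of the tuple element `gte`, recursing
// through nested kGetTupleElement operands until a non-GTE operand is
// reached.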
llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
                                                  llvm::Value* base_ptr) {
  // TODO(b/26344050): tighten the alignment based on the real element type.
  if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) {
    return llvm_ir::EmitGetTupleElement(
        gte->shape(), gte->tuple_index(), /*alignment=*/1,
        GetTypedIrValue(*gte->operand(0), {}, base_ptr), b_);
  }
  return llvm_ir::EmitGetTupleElement(
      gte->shape(), gte->tuple_index(), /*alignment=*/1,
      EmitGetTupleElement(gte->operand(0), base_ptr), b_);
}

// Returns true if `value` has a name that should not be changed.
static bool HasMeaningfulName(llvm::Value* value) {
  if (auto* global = llvm::dyn_cast<llvm::GlobalValue>(value)) {
    return global->getLinkage() != llvm::GlobalValue::PrivateLinkage;
  }
  return false;
}

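// Returns `ir_value` cast to a pointer to the IR type corresponding to
// `shape`, preserving the address space. Global variables are cast via a
// constant expression; all other values via an instruction at `b`'s insert
// point.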
llvm::Value* CastToTypedValue(const Shape& shape, llvm::Value* ir_value,
                              llvm::IRBuilder<>* b) {
  llvm::Type* pointee_type =
      llvm_ir::ShapeToIrType(shape, b->GetInsertBlock()->getModule());

  llvm::Type* dest_type = pointee_type->getPointerTo();

  llvm::Value* typed_ir_value;
  if (llvm::isa<llvm::GlobalVariable>(ir_value)) {
    typed_ir_value = llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast(
        llvm::cast<llvm::GlobalVariable>(ir_value), dest_type);
  } else {
    typed_ir_value =
        b->CreatePointerBitCastOrAddrSpaceCast(ir_value, dest_type);
  }
  return typed_ir_value;
}

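// Returns `ir_value` cast to a pointer to the IR type of `hlo`'s subshape at
// `shape_index`, naming both the raw and typed values after `hlo` unless
// they already have meaningful names.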
llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
                                              ShapeIndexView shape_index,
                                              llvm::Value* ir_value) {
  auto typed_ir_value = CastToTypedValue(
      ShapeUtil::GetSubshape(hlo.shape(), shape_index), ir_value, b_);
  if (!HasMeaningfulName(ir_value)) {
    ir_value->setName(llvm_ir::IrName(&hlo, "raw"));
  }
  if (!HasMeaningfulName(typed_ir_value)) {
    typed_ir_value->setName(llvm_ir::IrName(&hlo, "typed"));
  }
  return typed_ir_value;
}

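// Records `ir_value` (cast to its typed form) as the base pointer for `hlo`
// at `shape_index`, creating the instruction's ShapeTree of base pointers on
// first use.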
void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo,
                                       llvm::Value* ir_value,
                                       ShapeIndexView shape_index) {
  VLOG(2) << "Binding " << hlo.ToString();

  const Shape& hlo_shape = hlo.shape();
  llvm::Value* typed_ir_value = GetTypedIrValue(hlo, shape_index, ir_value);

  if (!BoundToIrValue(hlo)) {
    // Set the root of the ShapeTree before assigning the element ir value.
    InsertOrDie(&base_ptrs_, &hlo, ShapeTree<llvm::Value*>(hlo_shape, nullptr));
  }
  *(base_ptrs_[&hlo].mutable_element(shape_index)) = typed_ir_value;
}

// Determines whether `hlo`'s buffers are never modified within the execution
// of `consumer`.
static bool BuffersInvariantWithinConsumer(
    const HloInstruction& hlo, const HloInstruction& consumer,
    const BufferAssignment* buffer_assignment) {
  // Check if consumer is inside a fusion node -- if so, "dereference" it until
  // we get to a non-fusion node.
  const HloInstruction* c = &consumer;
  while (c->IsFused()) {
    c = c->parent()->FusionInstruction();
  }

  // If, after dereferencing `c`, we end up with a node that's not inside our
  // module's top-level computation (say our node is inside a while loop), we
  // give up on marking the array as invariant, because this HLO may be run
  // multiple times (e.g. multiple while loop iterations, or multiple
  // invocations of a reducer's computation).
  // TODO(jlebar): We could relax this constraint if we emitted an
  // llvm.invariant.group.barrier at the end of the computation.
  return c->parent() == c->GetModule()->entry_computation() &&
         buffer_assignment->HaveDisjointSlices(&hlo, &consumer);
}

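// Returns an IrArray for the buffer bound to `hlo` at `shape_index`, marking
// it invariant over the whole program when the buffer is provably unmodified
// during the execution of `consumer`.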
llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo,
                                             const HloInstruction& consumer,
                                             const ShapeIndex& shape_index) {
  llvm::Value* base_ptr = GetBasePointer(hlo, shape_index);
  CHECK_NE(base_ptr, nullptr)
      << "Buffer not assigned for shape_index " << shape_index.ToString()
      << " of " << hlo.ToString();
  llvm_ir::IrArray ir_array(base_ptr,
                            ShapeUtil::GetSubshape(hlo.shape(), shape_index));

  // The GPU backend emits one kernel per top-level HLO, and LLVM views
  // execution of one kernel as the "whole program" executed on the GPU.
  // Therefore if hlo's output buffer is not modified within consumer, and if
  // consumer runs hlo only once (so that it doesn't create two different
  // outputs), then we can mark ir_array as invariant over the whole program.
  if (!is_nested_ &&
      BuffersInvariantWithinConsumer(hlo, consumer, buffer_assignment_)) {
    VLOG(2) << "Marking " << hlo.name() << " as invariant within "
            << consumer.name();
    ir_array.MarkInvariantOverWholeProgram(&module_->getContext());
  }

  return ir_array;
}

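// Removes every binding whose base pointer is not a global variable, i.e.
// all bindings to function-local values such as arguments and allocas.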
void HloToIrBindings::UnbindAllLocalIrValues() {
  std::vector<const HloInstruction*> hlos_to_unbind;
  for (auto& key_value : base_ptrs_) {
    if (!llvm::isa<llvm::GlobalVariable>(
            (key_value.second.element({}))->stripPointerCasts())) {
      hlos_to_unbind.push_back(key_value.first);
    }
  }
  for (const HloInstruction* hlo_to_unbind : hlos_to_unbind) {
    VLOG(2) << "Unbinding " << hlo_to_unbind->ToString();
    base_ptrs_.erase(hlo_to_unbind);
  }
}

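// Produces a human-readable dump of all current bindings, grouped by
// computation in post order.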
string HloToIrBindings::ToString() const {
  string s = StrCat("** HloToIrBindings **\n");
  StrAppend(&s, "  is_nested_=", is_nested_, "\n");
  StrAppend(&s,
            "  temp_buffer_base_=", llvm_ir::DumpToString(*temp_buffer_base_),
            "\n");

  if (base_ptrs_.empty()) {
    return s;
  }

  // Iterate over all computations in the module in topological order, and
  // print out the base pointers we have in each computation in topological
  // order.
  for (const HloComputation* computation :
       base_ptrs_.begin()->first->GetModule()->MakeComputationPostOrder()) {
    bool is_first = true;
    for (const HloInstruction* instr :
         computation->MakeInstructionPostOrder()) {
      auto it = base_ptrs_.find(instr);
      if (it == base_ptrs_.end()) {
        continue;
      }
      if (is_first) {
        StrAppend(&s, "  Base pointers for computation ", computation->name(),
                  ":\n");
        is_first = false;
      }
      StrAppend(&s, "    ", instr->ToString());

      const ShapeTree<llvm::Value*>& shape_tree = it->second;
      if (!instr->shape().IsTuple()) {
        const llvm::Value* val = shape_tree.begin()->second;
        StrAppend(&s, " -> ", llvm_ir::DumpToString(*val), "\n");
        continue;
      }

      StrAppend(&s, "\n");
      for (auto shape_it = shape_tree.begin(); shape_it != shape_tree.end();
           ++shape_it) {
        llvm::Value* val = shape_it->second;
        StrAppend(&s, "      ", shape_it->first.ToString(), " -> ",
                  (val != nullptr ? llvm_ir::DumpToString(*val) : "null"),
                  "\n");
      }
    }
  }
  return s;
}

}  // namespace gpu
}  // namespace xla