/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace gpu {

using absl::StrAppend;
using absl::StrCat;

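// Binds the I/O HLOs of a nested computation to the parameters of the
// current IR function (every parameter except the trailing output argument),
// then emits bindings for the remaining HLOs: constants are bound to their
// module-level globals, kGetTupleElement is skipped, and every other non-I/O
// HLO gets an alloca in the function's entry block.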
void HloToIrBindings::EmitBasePointersForHlos(
    absl::Span<const HloInstruction* const> io_hlos,
    absl::Span<const HloInstruction* const> non_io_hlos) {
  CHECK(is_nested_);

  // I/O HLOs are bound to the arguments of the current IR function,
  // *excluding* the output argument, which is added to non-I/O HLOs.
  // I.e.,
  //
  // void IrFunction(io_0, io_1, ..., io_{m-1}, output_arg);
  llvm::Function* function = b_->GetInsertBlock()->getParent();
  CHECK_EQ(io_hlos.size() + 1, function->arg_size());

  // An HLO can have duplicated operands. This data structure remembers which
  // operand HLOs are already bound to avoid rebinding the same HLO.
  absl::flat_hash_set<const HloInstruction*> already_bound_for_this_function;
  auto arg_iter = function->arg_begin();
  for (const HloInstruction* io_hlo : io_hlos) {
    CHECK(io_hlo == io_hlo->parent()->root_instruction() ||
          !absl::c_count(non_io_hlos, io_hlo))
        << "IO HLOs and non-IO HLOs should be disjoint";
    if (!already_bound_for_this_function.contains(io_hlo)) {
      BindHloToIrValue(*io_hlo, &*arg_iter);
      already_bound_for_this_function.insert(io_hlo);
    }
    ++arg_iter;
  }

  // Name and skip the output parameter.
  arg_iter->setName("output_arg");
  ++arg_iter;

  for (const HloInstruction* non_io_hlo : non_io_hlos) {
    if (already_bound_for_this_function.contains(non_io_hlo)) {
      continue;
    }
    already_bound_for_this_function.insert(non_io_hlo);

    if (non_io_hlo->opcode() == HloOpcode::kGetTupleElement) {
      continue;
    }

    ShapeUtil::ForEachSubshape(
        non_io_hlo->shape(),
        [&](const Shape& /*subshape*/, const ShapeIndex& index) {
          if (non_io_hlo->opcode() == HloOpcode::kConstant) {
            llvm::Value* global_for_constant = module_->getGlobalVariable(
                llvm_ir::ConstantHloToGlobalName(*non_io_hlo));
            CHECK(global_for_constant)
                << llvm_ir::ConstantHloToGlobalName(*non_io_hlo);
            BindHloToIrValue(*non_io_hlo, global_for_constant);
          } else {
            llvm::Type* pointee_type =
                llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_);
            BindHloToIrValue(*non_io_hlo,
                             llvm_ir::EmitAllocaAtFunctionEntry(
                                 pointee_type, /*name=*/"", b_),
                             index);
          }
        });
  }
}

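// Emits the address of the tuple element produced by `gte`. `base_ptr` is the
// untyped pointer bound to the instruction at the bottom of the
// kGetTupleElement chain; nested get-tuple-elements are handled by recursing
// through the operand chain.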
llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
                                                  llvm::Value* base_ptr) {
  // TODO(b/26344050): tighten the alignment based on the real element type.
  if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) {
    return llvm_ir::EmitGetTupleElement(
        gte->shape(), gte->tuple_index(), /*alignment=*/1,
        GetTypedIrValue(*gte->operand(0), {}, base_ptr), b_);
  }
  return llvm_ir::EmitGetTupleElement(
      gte->shape(), gte->tuple_index(), /*alignment=*/1,
      EmitGetTupleElement(gte->operand(0), base_ptr), b_);
}

// Returns true if `value` has a name that should not be changed.
static bool HasMeaningfulName(llvm::Value* value) {
  if (auto* global = llvm::dyn_cast<llvm::GlobalValue>(value)) {
    return global->getLinkage() != llvm::GlobalValue::PrivateLinkage;
  }
  return false;
}

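// Casts `ir_value` to a pointer whose pointee type is the IR type derived
// from `shape`. Global variables are cast with a constant expression so the
// result remains a constant; other values are cast with an instruction at the
// builder's current insertion point.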
llvm::Value* CastToTypedValue(const Shape& shape, llvm::Value* ir_value,
                              llvm::IRBuilder<>* b) {
  llvm::Type* pointee_type =
      llvm_ir::ShapeToIrType(shape, b->GetInsertBlock()->getModule());

  llvm::Type* dest_type = pointee_type->getPointerTo();

  llvm::Value* typed_ir_value;
  if (llvm::isa<llvm::GlobalVariable>(ir_value)) {
    typed_ir_value = llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast(
        llvm::cast<llvm::GlobalVariable>(ir_value), dest_type);
  } else {
    typed_ir_value =
        b->CreatePointerBitCastOrAddrSpaceCast(ir_value, dest_type);
  }
  return typed_ir_value;
}

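// Returns `ir_value` cast to a pointer to the IR type of the subshape of
// `hlo` at `shape_index`, naming both the raw and the typed value after the
// HLO unless they already carry meaningful (non-private) names.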
llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
                                              ShapeIndexView shape_index,
                                              llvm::Value* ir_value) {
  auto typed_ir_value = CastToTypedValue(
      ShapeUtil::GetSubshape(hlo.shape(), shape_index), ir_value, b_);
  if (!HasMeaningfulName(ir_value)) {
    ir_value->setName(llvm_ir::IrName(&hlo, "raw"));
  }
  if (!HasMeaningfulName(typed_ir_value)) {
    typed_ir_value->setName(llvm_ir::IrName(&hlo, "typed"));
  }
  return typed_ir_value;
}

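// Records `ir_value` as the base pointer for the subshape of `hlo` at
// `shape_index`, creating the HLO's ShapeTree of base pointers on first use.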
void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo,
                                       llvm::Value* ir_value,
                                       ShapeIndexView shape_index) {
  VLOG(2) << "Binding " << hlo.ToString();

  const Shape& hlo_shape = hlo.shape();
  llvm::Value* typed_ir_value = GetTypedIrValue(hlo, shape_index, ir_value);

  if (!BoundToIrValue(hlo)) {
    // Set the root of the ShapeTree first before assigning the element value.
    InsertOrDie(&base_ptrs_, &hlo, ShapeTree<llvm::Value*>(hlo_shape, nullptr));
  }
  *(base_ptrs_[&hlo].mutable_element(shape_index)) = typed_ir_value;
}

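// Returns an IrArray for the subshape of `hlo` at `shape_index`, backed by
// the base pointer previously bound for it. Only used for nested
// computations; IrEmitterUnnested obtains its IrArrays from LMHLO instead.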
llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo,
                                             const HloInstruction& consumer,
                                             const ShapeIndex& shape_index) {
  CHECK(is_nested_)
      << "IrEmitterUnnested should instead use LMHLO to get the IrArray";

  llvm::Value* base_ptr = GetBasePointer(hlo, shape_index);
  CHECK_NE(base_ptr, nullptr)
      << "Buffer not assigned for shape_index " << shape_index.ToString()
      << " of " << hlo.ToString();
  llvm_ir::IrArray ir_array(base_ptr,
                            ShapeUtil::GetSubshape(hlo.shape(), shape_index));

  return ir_array;
}

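// Removes every binding whose root base pointer is function-local, i.e. not
// an LLVM global variable once pointer casts are stripped. Bindings to
// globals (such as constants) are kept.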
void HloToIrBindings::UnbindAllLocalIrValues() {
  std::vector<const HloInstruction*> hlos_to_unbind;
  for (auto& key_value : base_ptrs_) {
    if (!llvm::isa<llvm::GlobalVariable>(
            (key_value.second.element({}))->stripPointerCasts())) {
      hlos_to_unbind.push_back(key_value.first);
    }
  }
  for (const HloInstruction* hlo_to_unbind : hlos_to_unbind) {
    VLOG(2) << "Unbinding " << hlo_to_unbind->ToString();
    base_ptrs_.erase(hlo_to_unbind);
  }
}

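// Returns a human-readable dump of the bindings, listing the base pointers
// recorded for each computation in post order.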
string HloToIrBindings::ToString() const {
  string s = StrCat("** HloToIrBindings **\n");
  StrAppend(&s, "  is_nested_=", is_nested_, "\n");
  StrAppend(&s,
            "  temp_buffer_base_=", llvm_ir::DumpToString(*temp_buffer_base_),
            "\n");

  if (base_ptrs_.empty()) {
    return s;
  }

  // Iterate over all computations in the module in post order, and print out
  // the base pointers we have in each computation.
  for (const HloComputation* computation :
       base_ptrs_.begin()->first->GetModule()->MakeComputationPostOrder()) {
    bool is_first = true;
    for (const HloInstruction* instr :
         computation->MakeInstructionPostOrder()) {
      auto it = base_ptrs_.find(instr);
      if (it == base_ptrs_.end()) {
        continue;
      }
      if (is_first) {
        StrAppend(&s, "  Base pointers for computation ", computation->name(),
                  ":\n");
        is_first = false;
      }
      StrAppend(&s, "    ", instr->ToString());

      const ShapeTree<llvm::Value*>& shape_tree = it->second;
      if (!instr->shape().IsTuple()) {
        const llvm::Value* val = shape_tree.begin()->second;
        StrAppend(&s, " -> ", llvm_ir::DumpToString(*val), "\n");
        continue;
      }

      StrAppend(&s, "\n");
      for (auto shape_it = shape_tree.begin(); shape_it != shape_tree.end();
           ++shape_it) {
        llvm::Value* val = shape_it->second;
        StrAppend(&s, "      ", shape_it->first.ToString(), " -> ",
                  (val != nullptr ? llvm_ir::DumpToString(*val) : "null"),
                  "\n");
      }
    }
  }
  return s;
}

}  // namespace gpu
}  // namespace xla