• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
17 
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/strings/str_cat.h"
20 #include "llvm/IR/BasicBlock.h"
21 #include "llvm/IR/Function.h"
22 #include "llvm/IR/Instructions.h"
23 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
24 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
27 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
28 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/platform/types.h"
31 
32 namespace xla {
33 namespace gpu {
34 
35 using absl::StrAppend;
36 using absl::StrCat;
37 
EmitBasePointersForHlos(absl::Span<const HloInstruction * const> io_hlos,absl::Span<const HloInstruction * const> non_io_hlos)38 void HloToIrBindings::EmitBasePointersForHlos(
39     absl::Span<const HloInstruction* const> io_hlos,
40     absl::Span<const HloInstruction* const> non_io_hlos) {
41   CHECK(is_nested_);
42 
43   // I/O HLOs are bound to the arguments of the current IR function,
44   // *excluding* the output argument, which is added to non-I/O HLOs.
45   // I.e.,
46   //
47   // void IrFunction(io_0, io_1, ..., io_{m-1}, output_arg);
48   llvm::Function* function = b_->GetInsertBlock()->getParent();
49   CHECK_EQ(io_hlos.size() + 1, function->arg_size());
50 
51   // An HLO can have duplicated operands. This data structure remembers which
52   // operand HLOs are already bound to avoid rebinding the same HLO.
53   absl::flat_hash_set<const HloInstruction*> already_bound_for_this_function;
54   auto arg_iter = function->arg_begin();
55   for (const HloInstruction* io_hlo : io_hlos) {
56     CHECK(io_hlo == io_hlo->parent()->root_instruction() ||
57           !absl::c_count(non_io_hlos, io_hlo))
58         << "IO HLOs and non-IO HLOs should be disjoint";
59     if (!already_bound_for_this_function.contains(io_hlo)) {
60       BindHloToIrValue(*io_hlo, &*arg_iter);
61       already_bound_for_this_function.insert(io_hlo);
62     }
63     ++arg_iter;
64   }
65 
66   // Name and skip the output parameter.
67   arg_iter->setName("output_arg");
68   ++arg_iter;
69 
70   for (const HloInstruction* non_io_hlo : non_io_hlos) {
71     if (already_bound_for_this_function.contains(non_io_hlo)) {
72       continue;
73     }
74     already_bound_for_this_function.insert(non_io_hlo);
75 
76     if (non_io_hlo->opcode() == HloOpcode::kGetTupleElement) {
77       continue;
78     }
79 
80     ShapeUtil::ForEachSubshape(
81         non_io_hlo->shape(),
82         [&](const Shape& /*subshape*/, const ShapeIndex& index) {
83           if (non_io_hlo->opcode() == HloOpcode::kConstant) {
84             llvm::Value* global_for_constant = module_->getGlobalVariable(
85                 llvm_ir::ConstantHloToGlobalName(*non_io_hlo));
86             CHECK(global_for_constant)
87                 << llvm_ir::ConstantHloToGlobalName(*non_io_hlo);
88             BindHloToIrValue(*non_io_hlo, global_for_constant);
89           } else {
90             llvm::Type* pointee_type =
91                 llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_);
92             BindHloToIrValue(*non_io_hlo,
93                              llvm_ir::EmitAllocaAtFunctionEntry(
94                                  pointee_type, /*name=*/"", b_),
95                              index);
96           }
97         });
98   }
99 }
100 
EmitGetTupleElement(const HloInstruction * gte,llvm::Value * base_ptr)101 llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
102                                                   llvm::Value* base_ptr) {
103   // TODO(b/26344050): tighten the alignment based on the real element type.
104   if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) {
105     return llvm_ir::EmitGetTupleElement(
106         gte->shape(), gte->tuple_index(), /*alignment=*/1,
107         GetTypedIrValue(*gte->operand(0), {}, base_ptr), b_);
108   }
109   return llvm_ir::EmitGetTupleElement(
110       gte->shape(), gte->tuple_index(), /*alignment=*/1,
111       EmitGetTupleElement(gte->operand(0), base_ptr), b_);
112 }
113 
114 // Returns true if `value` has a name that should not be changed.
HasMeaningfulName(llvm::Value * value)115 static bool HasMeaningfulName(llvm::Value* value) {
116   if (auto* global = llvm::dyn_cast<llvm::GlobalValue>(value)) {
117     return global->getLinkage() != llvm::GlobalValue::PrivateLinkage;
118   }
119   return false;
120 }
121 
CastToTypedValue(const Shape & shape,llvm::Value * ir_value,llvm::IRBuilder<> * b)122 llvm::Value* CastToTypedValue(const Shape& shape, llvm::Value* ir_value,
123                               llvm::IRBuilder<>* b) {
124   llvm::Type* pointee_type =
125       llvm_ir::ShapeToIrType(shape, b->GetInsertBlock()->getModule());
126 
127   llvm::Type* dest_type = pointee_type->getPointerTo();
128 
129   llvm::Value* typed_ir_value;
130   if (llvm::isa<llvm::GlobalVariable>(ir_value)) {
131     typed_ir_value = llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast(
132         llvm::cast<llvm::GlobalVariable>(ir_value), dest_type);
133   } else {
134     typed_ir_value = b->CreatePointerBitCastOrAddrSpaceCast(
135         ir_value, pointee_type->getPointerTo());
136   }
137   return typed_ir_value;
138 }
139 
GetTypedIrValue(const HloInstruction & hlo,ShapeIndexView shape_index,llvm::Value * ir_value)140 llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
141                                               ShapeIndexView shape_index,
142                                               llvm::Value* ir_value) {
143   auto typed_ir_value = CastToTypedValue(
144       ShapeUtil::GetSubshape(hlo.shape(), shape_index), ir_value, b_);
145   if (!HasMeaningfulName(ir_value)) {
146     ir_value->setName(llvm_ir::IrName(&hlo, "raw"));
147   }
148   if (!HasMeaningfulName(typed_ir_value)) {
149     typed_ir_value->setName(llvm_ir::IrName(&hlo, "typed"));
150   }
151   return typed_ir_value;
152 }
153 
BindHloToIrValue(const HloInstruction & hlo,llvm::Value * ir_value,ShapeIndexView shape_index)154 void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo,
155                                        llvm::Value* ir_value,
156                                        ShapeIndexView shape_index) {
157   VLOG(2) << "Binding " << hlo.ToString();
158 
159   const Shape& hlo_shape = hlo.shape();
160   llvm::Value* typed_ir_value = GetTypedIrValue(hlo, shape_index, ir_value);
161 
162   if (!BoundToIrValue(hlo)) {
163     // Set the root of ShapeTree first before assigning the element ir value.
164     InsertOrDie(&base_ptrs_, &hlo, ShapeTree<llvm::Value*>(hlo_shape, nullptr));
165   }
166   *(base_ptrs_[&hlo].mutable_element(shape_index)) = typed_ir_value;
167 }
168 
169 // Determines whether hlo's buffers are never modified within the execution of
170 // consumer.
BuffersInvariantWithinConsumer(const HloInstruction & hlo,const HloInstruction & consumer,const BufferAssignment * buffer_assignment)171 static bool BuffersInvariantWithinConsumer(
172     const HloInstruction& hlo, const HloInstruction& consumer,
173     const BufferAssignment* buffer_assignment) {
174   // Check if consumer is inside a fusion node -- if so, "dereference" it until
175   // we get to a non-fusion node.
176   const HloInstruction* c = &consumer;
177   while (c->IsFused()) {
178     c = c->parent()->FusionInstruction();
179   }
180 
181   // If, after dereferencing c, we end up with a node that's not inside our
182   // module's top-level computation (say our node is inside a while loop), we
183   // give up on marking array as invariant, because this HLO may be run multiple
184   // times (e.g. multiple while loop iterations, or multiple invocations of a
185   // reducer's computation).  TODO(jlebar): We could relax this constraint if we
186   // emitted an llvm.invariant.group.barrier at the end of the computation.
187   return c->parent() == c->GetModule()->entry_computation() &&
188          buffer_assignment->HaveDisjointSlices(&hlo, &consumer);
189 }
190 
GetIrArray(const HloInstruction & hlo,const HloInstruction & consumer,const ShapeIndex & shape_index)191 llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo,
192                                              const HloInstruction& consumer,
193                                              const ShapeIndex& shape_index) {
194   llvm::Value* base_ptr = GetBasePointer(hlo, shape_index);
195   CHECK_NE(base_ptr, nullptr)
196       << "Buffer not assigned for shape_index " << shape_index.ToString()
197       << " of " << hlo.ToString();
198   llvm_ir::IrArray ir_array(base_ptr,
199                             ShapeUtil::GetSubshape(hlo.shape(), shape_index));
200 
201   // The GPU backend emits one kernel per top-level HLO, and LLVM views
202   // execution of one kernel as the "whole program" executed on the GPU.
203   // Therefore if hlo's output buffer is not modified within consumer, and if
204   // consumer runs hlo only once (so that it doesn't create two different
205   // outputs), then we can mark ir_array as invariant over the whole program.
206   if (!is_nested_ &&
207       BuffersInvariantWithinConsumer(hlo, consumer, buffer_assignment_)) {
208     VLOG(2) << "Marking " << hlo.name() << " as invariant within "
209             << consumer.name();
210     ir_array.MarkInvariantOverWholeProgram(&module_->getContext());
211   }
212 
213   return ir_array;
214 }
215 
UnbindAllLocalIrValues()216 void HloToIrBindings::UnbindAllLocalIrValues() {
217   std::vector<const HloInstruction*> hlos_to_unbind;
218   for (auto& key_value : base_ptrs_) {
219     if (!llvm::isa<llvm::GlobalVariable>(
220             (key_value.second.element({}))->stripPointerCasts())) {
221       hlos_to_unbind.push_back(key_value.first);
222     }
223   }
224   for (const HloInstruction* hlo_to_unbind : hlos_to_unbind) {
225     VLOG(2) << "Unbinding " << hlo_to_unbind->ToString();
226     base_ptrs_.erase(hlo_to_unbind);
227   }
228 }
229 
ToString() const230 string HloToIrBindings::ToString() const {
231   string s = StrCat("** HloToIrBindings **\n");
232   StrAppend(&s, "  is_nested_=", is_nested_, "\n");
233   StrAppend(&s,
234             "  temp_buffer_base_=", llvm_ir::DumpToString(*temp_buffer_base_),
235             "\n");
236 
237   if (base_ptrs_.empty()) {
238     return s;
239   }
240 
241   // Iterate over all computations in the module in topological order, and print
242   // out the base pointers we have in each computation in topological order.
243   for (const HloComputation* computation :
244        base_ptrs_.begin()->first->GetModule()->MakeComputationPostOrder()) {
245     bool is_first = true;
246     for (const HloInstruction* instr :
247          computation->MakeInstructionPostOrder()) {
248       auto it = base_ptrs_.find(instr);
249       if (it == base_ptrs_.end()) {
250         continue;
251       }
252       if (is_first) {
253         StrAppend(&s, "  Base pointers for computation ", computation->name(),
254                   ":\n");
255         is_first = false;
256       }
257       StrAppend(&s, "    ", instr->ToString());
258 
259       const ShapeTree<llvm::Value*>& shape_tree = it->second;
260       if (!instr->shape().IsTuple()) {
261         const llvm::Value* val = shape_tree.begin()->second;
262         StrAppend(&s, " -> ", llvm_ir::DumpToString(*val), "\n");
263         continue;
264       }
265 
266       StrAppend(&s, "\n");
267       for (auto shape_it = shape_tree.begin(); shape_it != shape_tree.end();
268            ++shape_it) {
269         llvm::Value* val = shape_it->second;
270         StrAppend(&s, "      ", shape_it->first.ToString(), " -> ",
271                   (val != nullptr ? llvm_ir::DumpToString(*val) : "null"),
272                   "\n");
273       }
274     }
275   }
276   return s;
277 }
278 
279 }  // namespace gpu
280 }  // namespace xla
281