• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
17 
18 #include "absl/container/flat_hash_set.h"
19 #include "absl/strings/str_cat.h"
20 #include "llvm/IR/BasicBlock.h"
21 #include "llvm/IR/Function.h"
22 #include "llvm/IR/Instructions.h"
23 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
24 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
25 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
26 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
27 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
28 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
29 #include "tensorflow/core/platform/logging.h"
30 #include "tensorflow/core/platform/types.h"
31 
32 namespace xla {
33 namespace gpu {
34 
35 using absl::StrAppend;
36 using absl::StrCat;
37 
EmitBasePointersForHlos(absl::Span<const HloInstruction * const> io_hlos,absl::Span<const HloInstruction * const> non_io_hlos)38 void HloToIrBindings::EmitBasePointersForHlos(
39     absl::Span<const HloInstruction* const> io_hlos,
40     absl::Span<const HloInstruction* const> non_io_hlos) {
41   CHECK(is_nested_);
42 
43   // I/O HLOs are bound to the arguments of the current IR function,
44   // *excluding* the output argument, which is added to non-I/O HLOs.
45   // I.e.,
46   //
47   // void IrFunction(io_0, io_1, ..., io_{m-1}, output_arg);
48   llvm::Function* function = b_->GetInsertBlock()->getParent();
49   CHECK_EQ(io_hlos.size() + 1, function->arg_size());
50 
51   // An HLO can have duplicated operands. This data structure remembers which
52   // operand HLOs are already bound to avoid rebinding the same HLO.
53   absl::flat_hash_set<const HloInstruction*> already_bound_for_this_function;
54   auto arg_iter = function->arg_begin();
55   for (const HloInstruction* io_hlo : io_hlos) {
56     CHECK(io_hlo == io_hlo->parent()->root_instruction() ||
57           !absl::c_count(non_io_hlos, io_hlo))
58         << "IO HLOs and non-IO HLOs should be disjoint";
59     if (!already_bound_for_this_function.contains(io_hlo)) {
60       BindHloToIrValue(*io_hlo, &*arg_iter);
61       already_bound_for_this_function.insert(io_hlo);
62     }
63     ++arg_iter;
64   }
65 
66   // Name and skip the output parameter.
67   arg_iter->setName("output_arg");
68   ++arg_iter;
69 
70   for (const HloInstruction* non_io_hlo : non_io_hlos) {
71     if (already_bound_for_this_function.contains(non_io_hlo)) {
72       continue;
73     }
74     already_bound_for_this_function.insert(non_io_hlo);
75 
76     if (non_io_hlo->opcode() == HloOpcode::kGetTupleElement) {
77       continue;
78     }
79 
80     ShapeUtil::ForEachSubshape(
81         non_io_hlo->shape(),
82         [&](const Shape& /*subshape*/, const ShapeIndex& index) {
83           if (non_io_hlo->opcode() == HloOpcode::kConstant) {
84             llvm::Value* global_for_constant = module_->getGlobalVariable(
85                 llvm_ir::ConstantHloToGlobalName(*non_io_hlo));
86             CHECK(global_for_constant)
87                 << llvm_ir::ConstantHloToGlobalName(*non_io_hlo);
88             BindHloToIrValue(*non_io_hlo, global_for_constant);
89           } else {
90             llvm::Type* pointee_type =
91                 llvm_ir::ShapeToIrType(non_io_hlo->shape(), module_);
92             BindHloToIrValue(*non_io_hlo,
93                              llvm_ir::EmitAllocaAtFunctionEntry(
94                                  pointee_type, /*name=*/"", b_),
95                              index);
96           }
97         });
98   }
99 }
100 
EmitGetTupleElement(const HloInstruction * gte,llvm::Value * base_ptr)101 llvm::Value* HloToIrBindings::EmitGetTupleElement(const HloInstruction* gte,
102                                                   llvm::Value* base_ptr) {
103   // TODO(b/26344050): tighten the alignment based on the real element type.
104   if (gte->operand(0)->opcode() != HloOpcode::kGetTupleElement) {
105     return llvm_ir::EmitGetTupleElement(
106         gte->shape(), gte->tuple_index(), /*alignment=*/1,
107         GetTypedIrValue(*gte->operand(0), {}, base_ptr), b_);
108   }
109   return llvm_ir::EmitGetTupleElement(
110       gte->shape(), gte->tuple_index(), /*alignment=*/1,
111       EmitGetTupleElement(gte->operand(0), base_ptr), b_);
112 }
113 
114 // Returns true if `value` has a name that should not be changed.
HasMeaningfulName(llvm::Value * value)115 static bool HasMeaningfulName(llvm::Value* value) {
116   if (auto* global = llvm::dyn_cast<llvm::GlobalValue>(value)) {
117     return global->getLinkage() != llvm::GlobalValue::PrivateLinkage;
118   }
119   return false;
120 }
121 
CastToTypedValue(const Shape & shape,llvm::Value * ir_value,llvm::IRBuilder<> * b)122 llvm::Value* CastToTypedValue(const Shape& shape, llvm::Value* ir_value,
123                               llvm::IRBuilder<>* b) {
124   llvm::Type* pointee_type =
125       llvm_ir::ShapeToIrType(shape, b->GetInsertBlock()->getModule());
126 
127   llvm::Type* dest_type = pointee_type->getPointerTo();
128 
129   llvm::Value* typed_ir_value;
130   if (llvm::isa<llvm::GlobalVariable>(ir_value)) {
131     typed_ir_value = llvm::ConstantExpr::getPointerBitCastOrAddrSpaceCast(
132         llvm::cast<llvm::GlobalVariable>(ir_value), dest_type);
133   } else {
134     typed_ir_value = b->CreatePointerBitCastOrAddrSpaceCast(
135         ir_value, pointee_type->getPointerTo());
136   }
137   return typed_ir_value;
138 }
139 
GetTypedIrValue(const HloInstruction & hlo,ShapeIndexView shape_index,llvm::Value * ir_value)140 llvm::Value* HloToIrBindings::GetTypedIrValue(const HloInstruction& hlo,
141                                               ShapeIndexView shape_index,
142                                               llvm::Value* ir_value) {
143   auto typed_ir_value = CastToTypedValue(
144       ShapeUtil::GetSubshape(hlo.shape(), shape_index), ir_value, b_);
145   if (!HasMeaningfulName(ir_value)) {
146     ir_value->setName(llvm_ir::IrName(&hlo, "raw"));
147   }
148   if (!HasMeaningfulName(typed_ir_value)) {
149     typed_ir_value->setName(llvm_ir::IrName(&hlo, "typed"));
150   }
151   return typed_ir_value;
152 }
153 
BindHloToIrValue(const HloInstruction & hlo,llvm::Value * ir_value,ShapeIndexView shape_index)154 void HloToIrBindings::BindHloToIrValue(const HloInstruction& hlo,
155                                        llvm::Value* ir_value,
156                                        ShapeIndexView shape_index) {
157   VLOG(2) << "Binding " << hlo.ToString();
158 
159   const Shape& hlo_shape = hlo.shape();
160   llvm::Value* typed_ir_value = GetTypedIrValue(hlo, shape_index, ir_value);
161 
162   if (!BoundToIrValue(hlo)) {
163     // Set the root of ShapeTree first before assigning the element ir value.
164     InsertOrDie(&base_ptrs_, &hlo, ShapeTree<llvm::Value*>(hlo_shape, nullptr));
165   }
166   *(base_ptrs_[&hlo].mutable_element(shape_index)) = typed_ir_value;
167 }
168 
GetIrArray(const HloInstruction & hlo,const HloInstruction & consumer,const ShapeIndex & shape_index)169 llvm_ir::IrArray HloToIrBindings::GetIrArray(const HloInstruction& hlo,
170                                              const HloInstruction& consumer,
171                                              const ShapeIndex& shape_index) {
172   CHECK(is_nested_)
173       << "IrEmitterUnnested should instead use LMHLO to get the IrArray";
174 
175   llvm::Value* base_ptr = GetBasePointer(hlo, shape_index);
176   CHECK_NE(base_ptr, nullptr)
177       << "Buffer not assigned for shape_index " << shape_index.ToString()
178       << " of " << hlo.ToString();
179   llvm_ir::IrArray ir_array(base_ptr,
180                             ShapeUtil::GetSubshape(hlo.shape(), shape_index));
181 
182   return ir_array;
183 }
184 
UnbindAllLocalIrValues()185 void HloToIrBindings::UnbindAllLocalIrValues() {
186   std::vector<const HloInstruction*> hlos_to_unbind;
187   for (auto& key_value : base_ptrs_) {
188     if (!llvm::isa<llvm::GlobalVariable>(
189             (key_value.second.element({}))->stripPointerCasts())) {
190       hlos_to_unbind.push_back(key_value.first);
191     }
192   }
193   for (const HloInstruction* hlo_to_unbind : hlos_to_unbind) {
194     VLOG(2) << "Unbinding " << hlo_to_unbind->ToString();
195     base_ptrs_.erase(hlo_to_unbind);
196   }
197 }
198 
ToString() const199 string HloToIrBindings::ToString() const {
200   string s = StrCat("** HloToIrBindings **\n");
201   StrAppend(&s, "  is_nested_=", is_nested_, "\n");
202   StrAppend(&s,
203             "  temp_buffer_base_=", llvm_ir::DumpToString(*temp_buffer_base_),
204             "\n");
205 
206   if (base_ptrs_.empty()) {
207     return s;
208   }
209 
210   // Iterate over all computations in the module in topological order, and print
211   // out the base pointers we have in each computation in topological order.
212   for (const HloComputation* computation :
213        base_ptrs_.begin()->first->GetModule()->MakeComputationPostOrder()) {
214     bool is_first = true;
215     for (const HloInstruction* instr :
216          computation->MakeInstructionPostOrder()) {
217       auto it = base_ptrs_.find(instr);
218       if (it == base_ptrs_.end()) {
219         continue;
220       }
221       if (is_first) {
222         StrAppend(&s, "  Base pointers for computation ", computation->name(),
223                   ":\n");
224         is_first = false;
225       }
226       StrAppend(&s, "    ", instr->ToString());
227 
228       const ShapeTree<llvm::Value*>& shape_tree = it->second;
229       if (!instr->shape().IsTuple()) {
230         const llvm::Value* val = shape_tree.begin()->second;
231         StrAppend(&s, " -> ", llvm_ir::DumpToString(*val), "\n");
232         continue;
233       }
234 
235       StrAppend(&s, "\n");
236       for (auto shape_it = shape_tree.begin(); shape_it != shape_tree.end();
237            ++shape_it) {
238         llvm::Value* val = shape_it->second;
239         StrAppend(&s, "      ", shape_it->first.ToString(), " -> ",
240                   (val != nullptr ? llvm_ir::DumpToString(*val) : "null"),
241                   "\n");
242       }
243     }
244   }
245   return s;
246 }
247 
248 }  // namespace gpu
249 }  // namespace xla
250