/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"

#include <memory>
#include <numeric>
#include <vector>

#include "absl/strings/str_cat.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace llvm_ir {
ForLoop::ForLoop(absl::string_view prefix, absl::string_view suffix,
                 llvm::Value* start_index, llvm::Value* end_index,
                 llvm::Value* step, UnrollMode unroll_mode,
                 bool prevent_vectorization)
    : prefix_(prefix),
      suffix_(suffix),
      start_index_(start_index),
      end_index_(end_index),
      step_(step),
      insert_before_bb_(nullptr),
      unroll_mode_(unroll_mode),
      prevent_vectorization_(prevent_vectorization) {}

/* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop(
    absl::string_view prefix, llvm::Value* start_index, llvm::Value* end_index,
    llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode,
    bool prevent_vectorization) {
  std::unique_ptr<ForLoop> loop(new ForLoop(prefix, /*suffix=*/"", start_index,
                                            end_index, step, unroll_mode,
                                            prevent_vectorization));
  loop->Emit(b);
  return loop;
}
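
// Illustrative usage sketch (not part of the implementation): emitting
// `for (i = start; i < limit; ++i)` with the factory above, assuming an
// IRBuilder `b` and i64 values `start` and `limit` already exist, and using
// the kDefaultUnroll mode declared in llvm_loop.h:
//
//   auto loop = ForLoop::EmitForLoop("copy", start, limit, b->getInt64(1), b,
//                                    UnrollMode::kDefaultUnroll,
//                                    /*prevent_vectorization=*/false);
//   // `b` is now positioned at the loop's exit block. To emit the body,
//   // reposition it inside loop->GetBodyBasicBlock() (e.g. before that
//   // block's terminator) and use loop->GetIndVarValue() as the index.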

void ForLoop::Emit(llvm::IRBuilder<>* b) {
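  // Block structure emitted below: the preheader branches to the header; the
  // header compares the induction variable against end_index_ and branches
  // either to the exit block or to the body; the body increments the
  // induction variable and branches back to the header.
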
  // The preheader block is the block the builder is currently emitting
  // code into.
  preheader_bb_ = b->GetInsertBlock();

  llvm::BasicBlock::iterator insert_point = b->GetInsertPoint();
  if (insert_point == preheader_bb_->end()) {
    // We're emitting the loop at the end of a basic block. Verify there is no
    // terminator (eg, branch) in the basic block.
    CHECK_EQ(nullptr, preheader_bb_->getTerminator());

    exit_bb_ = CreateLoopBB("loop_exit", b);
  } else {
    // We're emitting the loop into the middle of a basic block. splitBasicBlock
    // requires that this basic block be well-formed (have a terminator).
    CHECK_NE(nullptr, preheader_bb_->getTerminator());

    // Split the preheader to create an exit basic block. The exit basic block
    // will contain all instructions at or after insert_point.
    exit_bb_ = preheader_bb_->splitBasicBlock(insert_point,
                                              GetQualifiedName("loop_exit"));

    // splitBasicBlock adds an unconditional branch between the split basic
    // blocks. Remove it. An unconditional branch will be added below from the
    // preheader to the header.
    preheader_bb_->getTerminator()->eraseFromParent();
  }
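  // Blocks created by CreateLoopBB below (the header and the body) are
  // inserted immediately before the exit block, so the blocks appear in the
  // order preheader, header, body, exit.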
  insert_before_bb_ = exit_bb_;

  // Create the remaining basic blocks which form the inside of the loop.
  header_bb_ = CreateLoopBB("loop_header", b);
  body_bb_ = CreateLoopBB("loop_body", b);

  // Function entry basic block.
  // Emit an alloca for the induction variable. We do this in the function's
  // entry basic block to ensure the alloca executes only once per function
  // (we could be emitting a nested loop).
  llvm::Function* func = preheader_bb_->getParent();
  b->SetInsertPoint(&func->getEntryBlock(),
                    func->getEntryBlock().getFirstInsertionPt());
  llvm::Value* indvar_address = b->CreateAlloca(
      start_index_->getType(), nullptr, GetQualifiedName("invar_address"));

  // Preheader basic block.
  // Initialize induction variable starting index. Create branch to the header.
  b->SetInsertPoint(preheader_bb_);
  b->CreateStore(start_index_, indvar_address);
  // The preheader should not have a branch yet.
  CHECK_EQ(preheader_bb_->getTerminator(), nullptr);
  b->CreateBr(header_bb_);

  // Header basic block.
  // Emit the loop conditional branch. Load the indvar, compare it against the
  // ending index, and jump to the loop exit if it is >= the ending index.
  // Jump to the body otherwise.
  b->SetInsertPoint(header_bb_);
  indvar_ = b->CreateLoad(indvar_address, GetQualifiedName("indvar"));
  llvm::Value* exit_cond = b->CreateICmpUGE(indvar_, end_index_);
  b->CreateCondBr(/*Cond=*/exit_cond,
                  /*True=*/exit_bb_, /*False=*/body_bb_);

  // Body basic block.
  // Increment indvar, store indvar, and jump to header.
  b->SetInsertPoint(body_bb_);
  llvm::Value* step = step_;
  llvm::Value* indvar = indvar_;

  llvm::Value* indvar_inc = b->CreateAdd(indvar, step, "invar.inc",
                                         /*HasNUW=*/true, /*HasNSW=*/true);
  b->CreateStore(indvar_inc, indvar_address);
  llvm::BranchInst* back_branch = b->CreateBr(header_bb_);
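
  // Attach loop metadata to the branch back to the header. Per LLVM
  // convention the loop ID is a single metadata node whose first operand is a
  // self-reference; the temporary node below is only a placeholder that is
  // replaced by that self-reference once the node has been built.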
  std::vector<llvm::Metadata*> loop_metadata = GetLoopMetadata(b);
  if (!loop_metadata.empty()) {
    llvm::LLVMContext* ctx = &start_index_->getContext();
    auto temp_node = llvm::MDNode::getTemporary(*ctx, llvm::None);
    loop_metadata.insert(loop_metadata.begin(), temp_node.get());
    auto loop_id = llvm::MDNode::get(*ctx, loop_metadata);
    loop_id->replaceOperandWith(0, loop_id);
    back_branch->setMetadata(llvm::LLVMContext::MD_loop, loop_id);
  }

  // Re-point the IR builder to the loop exit block.
  b->SetInsertPoint(exit_bb_);
}

std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata(llvm::IRBuilder<>* b) {
  const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable";
  const char* const kLlvmLoopUnrollFullMDName = "llvm.loop.unroll.full";
  const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable";
  llvm::LLVMContext* ctx = &start_index_->getContext();

  std::vector<llvm::Metadata*> result;
  if (unroll_mode_ == xla::llvm_ir::UnrollMode::kNoUnroll) {
    result.push_back(llvm::MDNode::get(
        *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)}));
  }

  if (prevent_vectorization_) {
    result.push_back(llvm::MDNode::get(
        *ctx, {llvm::MDString::get(*ctx, kLlvmLoopVectorizeMDName),
               llvm::ConstantAsMetadata::get(b->getFalse())}));
  }

  if (unroll_mode_ == xla::llvm_ir::UnrollMode::kFullyUnroll) {
    result.push_back(llvm::MDNode::get(
        *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollFullMDName)}));
  }
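
  // Illustrative example (not emitted verbatim): with kNoUnroll and
  // prevent_vectorization_ both in effect, the loop ID attached to the back
  // branch in Emit() ends up looking roughly like:
  //   !0 = !{!0, !1, !2}
  //   !1 = !{!"llvm.loop.unroll.disable"}
  //   !2 = !{!"llvm.loop.vectorize.enable", i1 false}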
  return result;
}

string ForLoop::GetQualifiedName(absl::string_view name) {
  return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_));
}

llvm::BasicBlock* ForLoop::CreateLoopBB(absl::string_view name,
                                        llvm::IRBuilder<>* b) {
  return CreateBasicBlock(insert_before_bb_, GetQualifiedName(name), b);
}

std::unique_ptr<ForLoop> ForLoopNest::AddLoop(absl::string_view suffix,
                                              llvm::Value* start_index,
                                              llvm::Value* end_index,
                                              UnrollMode unroll_mode,
                                              bool prevent_vectorization) {
  return AddLoop(suffix, start_index, end_index, GetConstantWithIndexType(1),
                 unroll_mode, prevent_vectorization);
}

std::unique_ptr<ForLoop> ForLoopNest::AddLoop(
    absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index,
    llvm::Value* stride, UnrollMode unroll_mode, bool prevent_vectorization) {
  if (inner_loop_body_bb_ != nullptr) {
    // Create this loop inside the previous one.
    b_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt());
  }
  std::unique_ptr<ForLoop> loop(new ForLoop(
      /*prefix=*/name_, suffix, start_index, end_index, stride, unroll_mode,
      prevent_vectorization));
  loop->Emit(b_);

  if (outer_loop_preheader_bb_ == nullptr) {
    outer_loop_preheader_bb_ = loop->GetPreheaderBasicBlock();
  }

  if (outer_loop_exit_bb_ == nullptr) {
    outer_loop_exit_bb_ = loop->GetExitBasicBlock();
  }

  inner_loop_body_bb_ = loop->GetBodyBasicBlock();

  return loop;
}
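
// Illustrative usage sketch (not part of the implementation): building a 2-D
// loop nest, assuming an IRBuilder `b` positioned where the nest should be
// emitted and the ForLoopNest constructor and accessors declared in
// llvm_loop.h:
//
//   ForLoopNest nest("copy", b);
//   auto y_loop = nest.AddLoop(/*start_index=*/0, /*end_index=*/8, "y");
//   auto x_loop = nest.AddLoop(/*start_index=*/0, /*end_index=*/16, "x");
//   // The "x" loop is emitted inside the body of the "y" loop. Code emitted
//   // at nest.GetInnerLoopBodyBasicBlock() runs once per (y, x) pair, with
//   // the induction variables available via y_loop->GetIndVarValue() and
//   // x_loop->GetIndVarValue().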

std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
                                              int64 end_index,
                                              absl::string_view suffix,
                                              UnrollMode unroll_mode,
                                              bool prevent_vectorization) {
  CHECK_LE(start_index, end_index);
  return AddLoop(suffix, GetConstantWithIndexType(start_index),
                 GetConstantWithIndexType(end_index), unroll_mode,
                 prevent_vectorization);
}

std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
                                              int64 end_index, int64 stride,
                                              absl::string_view suffix,
                                              UnrollMode unroll_mode,
                                              bool prevent_vectorization) {
  CHECK_LE(start_index, end_index);
  return AddLoop(suffix, GetConstantWithIndexType(start_index),
                 GetConstantWithIndexType(end_index),
                 GetConstantWithIndexType(stride), unroll_mode,
                 prevent_vectorization);
}

IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,
                                             absl::string_view suffix) {
  std::vector<int64> dimensions(shape.rank());
  std::iota(dimensions.begin(), dimensions.end(), 0);
  return IrArray::Index(AddLoopsForShapeOnDimensions(shape, dimensions, suffix),
                        shape, index_type_);
}

std::vector<llvm::Value*> ForLoopNest::AddLoopsForShapeOnDimensions(
    const Shape& shape, absl::Span<const int64> dimensions,
    absl::string_view suffix) {
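  // Dimensions not listed in `dimensions` are left as nullptr in the returned
  // multi-index; callers (e.g. EmitOperandArrayLoopNest below) are responsible
  // for handling or filling in those entries.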
  std::vector<llvm::Value*> multi_index(shape.dimensions_size());
  for (int64 dimension : dimensions) {
    std::unique_ptr<llvm_ir::ForLoop> loop = AddLoop(
        /*start_index=*/0,
        /*end_index=*/shape.dimensions(dimension),
        /*suffix=*/
        llvm_ir::IrName(suffix, absl::StrCat(dimension)));
    multi_index[dimension] = loop->GetIndVarValue();
  }
  return multi_index;
}

std::vector<llvm::Value*> ForLoopNest::EmitOperandArrayLoopNest(
    const llvm_ir::IrArray& operand_array, int64 dimension_to_skip,
    absl::string_view name_suffix) {
  // Prepares the dimension list we will use to emit the loop nest. Outermost
  // loops are added first: dimensions are taken in the order given by
  // LayoutUtil::MinorToMajor, skipping the 'dimension_to_skip' dimension, so
  // the minor-most dimension becomes the outermost loop.
  std::vector<int64> dimensions;
  const Shape& shape = operand_array.GetShape();
  for (int64 dimension : LayoutUtil::MinorToMajor(shape)) {
    if (dimension != dimension_to_skip) {
      dimensions.push_back(dimension);
    }
  }

  // Create the loop nest, with one for-loop for each dimension in
  // 'dimensions'.
  std::vector<llvm::Value*> multi_index =
      AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix);
  // Verify every dimension except the 'dimension_to_skip' dimension was set in
  // the index.
  for (size_t dimension = 0; dimension < multi_index.size(); ++dimension) {
    if (dimension == dimension_to_skip) {
      DCHECK_EQ(nullptr, multi_index[dimension]);
    } else {
      DCHECK_NE(nullptr, multi_index[dimension]);
    }
  }
  return multi_index;
}

}  // namespace llvm_ir
}  // namespace xla