1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
17
18 #include <numeric>
19 #include <vector>
20
21 #include "absl/algorithm/container.h"
22 #include "absl/strings/str_cat.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/Function.h"
25 #include "llvm/IR/Instructions.h"
26 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/core/platform/logging.h"
30
31 namespace xla {
32 namespace llvm_ir {
33
ForLoop(absl::string_view prefix,absl::string_view suffix,llvm::Value * start_index,llvm::Value * end_index,llvm::Value * step,UnrollMode unroll_mode,bool prevent_vectorization)34 ForLoop::ForLoop(absl::string_view prefix, absl::string_view suffix,
35 llvm::Value* start_index, llvm::Value* end_index,
36 llvm::Value* step, UnrollMode unroll_mode,
37 bool prevent_vectorization)
38 : prefix_(prefix),
39 suffix_(suffix),
40 start_index_(start_index),
41 end_index_(end_index),
42 step_(step),
43 insert_before_bb_(nullptr),
44 unroll_mode_(unroll_mode),
45 prevent_vectorization_(prevent_vectorization) {}
46
EmitForLoop(absl::string_view prefix,llvm::Value * start_index,llvm::Value * end_index,llvm::Value * step,llvm::IRBuilder<> * b,UnrollMode unroll_mode,bool prevent_vectorization)47 /* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop(
48 absl::string_view prefix, llvm::Value* start_index, llvm::Value* end_index,
49 llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode,
50 bool prevent_vectorization) {
51 std::unique_ptr<ForLoop> loop(new ForLoop(prefix, /*suffix=*/"", start_index,
52 end_index, step, unroll_mode,
53 prevent_vectorization));
54 loop->Emit(b);
55 return loop;
56 }
57
Emit(llvm::IRBuilder<> * b)58 void ForLoop::Emit(llvm::IRBuilder<>* b) {
59 // The preheader block is the block the builder is currently emitting
60 // code into.
61 preheader_bb_ = b->GetInsertBlock();
62
63 llvm::BasicBlock::iterator insert_point = b->GetInsertPoint();
64 if (insert_point == preheader_bb_->end()) {
65 // We're emitting the loop at the end of a basic block. Verify there is no
66 // terminator (eg, branch) in the basic block.
67 CHECK_EQ(nullptr, preheader_bb_->getTerminator());
68
69 exit_bb_ = CreateLoopBB("loop_exit", b);
70 } else {
71 // We're emitting the loop into the middle of a basic block. splitBasicBlock
72 // requires that this basic block be well-formed (have a terminator).
73 CHECK_NE(nullptr, preheader_bb_->getTerminator());
74
75 // Split the preheader to create an exit basic block. The exit basic block
76 // will contain all instructions at or after insert_point.
77 exit_bb_ = preheader_bb_->splitBasicBlock(insert_point,
78 GetQualifiedName("loop_exit"));
79
80 // splitBasicBlock adds an unconditional branch between the split basic
81 // blocks. Remove it. An unconditional branch will be added below from the
82 // preheader to the header.
83 preheader_bb_->getTerminator()->eraseFromParent();
84 }
85 insert_before_bb_ = exit_bb_;
86
87 // Create remaining basic block which form the inside of the loop.
88 header_bb_ = CreateLoopBB("loop_header", b);
89 body_bb_ = CreateLoopBB("loop_body", b);
90
91 // Function entry basic block.
92 // Emit alloca for the induction variable. We do this at the entry to the
93 // basic block to ensure the alloc only executes once per function (we could
94 // be emitting a nested loop).
95 llvm::Function* func = preheader_bb_->getParent();
96 b->SetInsertPoint(&func->getEntryBlock(),
97 func->getEntryBlock().getFirstInsertionPt());
98 llvm::Value* indvar_address = b->CreateAlloca(
99 start_index_->getType(), nullptr, GetQualifiedName("invar_address"));
100
101 // Preheader basic block.
102 // Initialize induction variable starting index. Create branch to the header.
103 b->SetInsertPoint(preheader_bb_);
104 b->CreateStore(start_index_, indvar_address);
105 // The preheader should not have a branch yet.
106 CHECK_EQ(preheader_bb_->getTerminator(), nullptr);
107 b->CreateBr(header_bb_);
108
109 // Header basic block.
110 // Emit the loop conditional branch. Load and compare indvar with ending
111 // index and jump to loop exit if equal. Jump to body otherwise.
112 b->SetInsertPoint(header_bb_);
113 indvar_ = b->CreateLoad(start_index_->getType(), indvar_address,
114 GetQualifiedName("indvar"));
115 llvm::Value* exit_cond = b->CreateICmpUGE(indvar_, end_index_);
116 b->CreateCondBr(/*Cond=*/exit_cond,
117 /*True=*/exit_bb_, /*False=*/body_bb_);
118
119 // Body basic block.
120 // Increment indvar, store indvar, and jump to header.
121 b->SetInsertPoint(body_bb_);
122 llvm::Value* step = step_;
123 llvm::Value* indvar = indvar_;
124
125 llvm::Value* indvar_inc = b->CreateAdd(indvar, step, "invar.inc",
126 /*HasNUW=*/true, /*HasNSW=*/true);
127 b->CreateStore(indvar_inc, indvar_address);
128 llvm::BranchInst* back_branch = b->CreateBr(header_bb_);
129
130 std::vector<llvm::Metadata*> loop_metadata = GetLoopMetadata(b);
131 if (!loop_metadata.empty()) {
132 llvm::LLVMContext* ctx = &start_index_->getContext();
133 auto temp_node = llvm::MDNode::getTemporary(*ctx, llvm::None);
134 loop_metadata.insert(loop_metadata.begin(), temp_node.get());
135 auto loop_id = llvm::MDNode::get(*ctx, loop_metadata);
136 loop_id->replaceOperandWith(0, loop_id);
137 back_branch->setMetadata(llvm::LLVMContext::MD_loop, loop_id);
138 }
139
140 // Re-point the IR builder to the loop exit block.
141 b->SetInsertPoint(exit_bb_);
142 }
143
GetLoopMetadata(llvm::IRBuilder<> * b)144 std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata(llvm::IRBuilder<>* b) {
145 const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable";
146 const char* const kLlvmLoopUnrollFullMDName = "llvm.loop.unroll.full";
147 const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable";
148 llvm::LLVMContext* ctx = &start_index_->getContext();
149
150 std::vector<llvm::Metadata*> result;
151 if (unroll_mode_ == xla::llvm_ir::UnrollMode::kNoUnroll) {
152 result.push_back(llvm::MDNode::get(
153 *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)}));
154 }
155
156 if (prevent_vectorization_) {
157 result.push_back(llvm::MDNode::get(
158 *ctx, {llvm::MDString::get(*ctx, kLlvmLoopVectorizeMDName),
159 llvm::ConstantAsMetadata::get(b->getFalse())}));
160 }
161
162 if (unroll_mode_ == xla::llvm_ir::UnrollMode::kFullyUnroll) {
163 result.push_back(llvm::MDNode::get(
164 *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollFullMDName)}));
165 }
166 return result;
167 }
168
GetQualifiedName(absl::string_view name)169 std::string ForLoop::GetQualifiedName(absl::string_view name) {
170 return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_));
171 }
172
CreateLoopBB(absl::string_view name,llvm::IRBuilder<> * b)173 llvm::BasicBlock* ForLoop::CreateLoopBB(absl::string_view name,
174 llvm::IRBuilder<>* b) {
175 return CreateBasicBlock(insert_before_bb_, GetQualifiedName(name), b);
176 }
177
AddLoop(absl::string_view suffix,llvm::Value * start_index,llvm::Value * end_index,UnrollMode unroll_mode,bool prevent_vectorization)178 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(absl::string_view suffix,
179 llvm::Value* start_index,
180 llvm::Value* end_index,
181 UnrollMode unroll_mode,
182 bool prevent_vectorization) {
183 return AddLoop(suffix, start_index, end_index, GetConstantWithIndexType(1),
184 unroll_mode, prevent_vectorization);
185 }
186
AddLoop(absl::string_view suffix,llvm::Value * start_index,llvm::Value * end_index,llvm::Value * stride,UnrollMode unroll_mode,bool prevent_vectorization)187 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(
188 absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index,
189 llvm::Value* stride, UnrollMode unroll_mode, bool prevent_vectorization) {
190 if (inner_loop_body_bb_ != nullptr) {
191 // Create this loop inside the previous one.
192 b_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt());
193 }
194 std::unique_ptr<ForLoop> loop(new ForLoop(
195 /*prefix=*/name_, suffix, start_index, end_index, stride, unroll_mode,
196 prevent_vectorization));
197 loop->Emit(b_);
198
199 if (outer_loop_preheader_bb_ == nullptr) {
200 outer_loop_preheader_bb_ = loop->GetPreheaderBasicBlock();
201 }
202
203 if (outer_loop_exit_bb_ == nullptr) {
204 outer_loop_exit_bb_ = loop->GetExitBasicBlock();
205 }
206
207 inner_loop_body_bb_ = loop->GetBodyBasicBlock();
208
209 return loop;
210 }
211
AddLoop(int64_t start_index,int64_t end_index,absl::string_view suffix,UnrollMode unroll_mode,bool prevent_vectorization)212 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64_t start_index,
213 int64_t end_index,
214 absl::string_view suffix,
215 UnrollMode unroll_mode,
216 bool prevent_vectorization) {
217 CHECK_LE(start_index, end_index);
218 return AddLoop(suffix, GetConstantWithIndexType(start_index),
219 GetConstantWithIndexType(end_index), unroll_mode,
220 prevent_vectorization);
221 }
222
AddLoop(int64_t start_index,int64_t end_index,int64_t stride,absl::string_view suffix,UnrollMode unroll_mode,bool prevent_vectorization)223 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64_t start_index,
224 int64_t end_index, int64_t stride,
225 absl::string_view suffix,
226 UnrollMode unroll_mode,
227 bool prevent_vectorization) {
228 CHECK_LE(start_index, end_index);
229 return AddLoop(suffix, GetConstantWithIndexType(start_index),
230 GetConstantWithIndexType(end_index),
231 GetConstantWithIndexType(stride), unroll_mode,
232 prevent_vectorization);
233 }
234
AddLoopsForShape(const Shape & shape,absl::string_view suffix)235 IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,
236 absl::string_view suffix) {
237 std::vector<int64_t> dimensions(shape.rank());
238 std::iota(dimensions.begin(), dimensions.end(), 0);
239 return IrArray::Index(AddLoopsForShapeOnDimensions(shape, dimensions, suffix),
240 shape, index_type_);
241 }
242
AddLoopsForShapeOnDimensions(const Shape & shape,absl::Span<const int64_t> dimensions,absl::string_view suffix)243 std::vector<llvm::Value*> ForLoopNest::AddLoopsForShapeOnDimensions(
244 const Shape& shape, absl::Span<const int64_t> dimensions,
245 absl::string_view suffix) {
246 std::vector<llvm::Value*> multi_index(shape.dimensions_size());
247 for (int64_t dimension : dimensions) {
248 std::unique_ptr<llvm_ir::ForLoop> loop = AddLoop(
249 /*start_index=*/0,
250 /*end_index=*/shape.dimensions(dimension),
251 /*suffix=*/
252 llvm_ir::IrName(suffix, absl::StrCat(dimension)));
253 multi_index[dimension] = loop->GetIndVarValue();
254 }
255 return multi_index;
256 }
257
EmitOperandArrayLoopNest(const llvm_ir::IrArray & operand_array,int64_t dimension_to_skip,absl::string_view name_suffix)258 std::vector<llvm::Value*> ForLoopNest::EmitOperandArrayLoopNest(
259 const llvm_ir::IrArray& operand_array, int64_t dimension_to_skip,
260 absl::string_view name_suffix) {
261 // Prepares the dimension list we will use to emit the loop nest. Outermost
262 // loops are added first. Add loops in major-to-minor order, and skip the
263 // 'dimension_to_skip' dimension.
264 std::vector<int64_t> dimensions;
265 const Shape& shape = operand_array.GetShape();
266 // Initially get the dimensions in minor to major order, then reverse them.
267 for (int64_t dimension : LayoutUtil::MinorToMajor(shape)) {
268 if (dimension != dimension_to_skip) {
269 dimensions.push_back(dimension);
270 }
271 }
272 absl::c_reverse(dimensions);
273
274 // Create loop nest with one for-loop for each dimension of the
275 // output.
276 std::vector<llvm::Value*> multi_index =
277 AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix);
278 // Verify every dimension except the 'dimension_to_skip' dimension was set in
279 // the index.
280 for (size_t dimension = 0; dimension < multi_index.size(); ++dimension) {
281 if (dimension == dimension_to_skip) {
282 DCHECK_EQ(nullptr, multi_index[dimension]);
283 } else {
284 DCHECK_NE(nullptr, multi_index[dimension]);
285 }
286 }
287 return multi_index;
288 }
289
290 } // namespace llvm_ir
291 } // namespace xla
292