/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
17
18 #include <numeric>
19 #include <vector>
20
21 #include "absl/algorithm/container.h"
22 #include "absl/strings/str_cat.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/Function.h"
25 #include "llvm/IR/Instructions.h"
26 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/core/platform/logging.h"
30
31 namespace xla {
32 namespace llvm_ir {
33
ForLoop(absl::string_view prefix,absl::string_view suffix,llvm::Value * start_index,llvm::Value * end_index,llvm::Value * step,UnrollMode unroll_mode,bool prevent_vectorization)34 ForLoop::ForLoop(absl::string_view prefix, absl::string_view suffix,
35 llvm::Value* start_index, llvm::Value* end_index,
36 llvm::Value* step, UnrollMode unroll_mode,
37 bool prevent_vectorization)
38 : prefix_(prefix),
39 suffix_(suffix),
40 start_index_(start_index),
41 end_index_(end_index),
42 step_(step),
43 insert_before_bb_(nullptr),
44 unroll_mode_(unroll_mode),
45 prevent_vectorization_(prevent_vectorization) {}
46
EmitForLoop(absl::string_view prefix,llvm::Value * start_index,llvm::Value * end_index,llvm::Value * step,llvm::IRBuilder<> * b,UnrollMode unroll_mode,bool prevent_vectorization)47 /* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop(
48 absl::string_view prefix, llvm::Value* start_index, llvm::Value* end_index,
49 llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode,
50 bool prevent_vectorization) {
51 std::unique_ptr<ForLoop> loop(new ForLoop(prefix, /*suffix=*/"", start_index,
52 end_index, step, unroll_mode,
53 prevent_vectorization));
54 loop->Emit(b);
55 return loop;
56 }
57
Emit(llvm::IRBuilder<> * b)58 void ForLoop::Emit(llvm::IRBuilder<>* b) {
59 // The preheader block is the block the builder is currently emitting
60 // code into.
61 preheader_bb_ = b->GetInsertBlock();
62
63 llvm::BasicBlock::iterator insert_point = b->GetInsertPoint();
64 if (insert_point == preheader_bb_->end()) {
65 // We're emitting the loop at the end of a basic block. Verify there is no
66 // terminator (eg, branch) in the basic block.
67 CHECK_EQ(nullptr, preheader_bb_->getTerminator());
68
69 exit_bb_ = CreateLoopBB("loop_exit", b);
70 } else {
71 // We're emitting the loop into the middle of a basic block. splitBasicBlock
72 // requires that this basic block be well-formed (have a terminator).
73 CHECK_NE(nullptr, preheader_bb_->getTerminator());
74
75 // Split the preheader to create an exit basic block. The exit basic block
76 // will contain all instructions at or after insert_point.
77 exit_bb_ = preheader_bb_->splitBasicBlock(insert_point,
78 GetQualifiedName("loop_exit"));
79
80 // splitBasicBlock adds an unconditional branch between the split basic
81 // blocks. Remove it. An unconditional branch will be added below from the
82 // preheader to the header.
83 preheader_bb_->getTerminator()->eraseFromParent();
84 }
85 insert_before_bb_ = exit_bb_;
86
87 // Create remaining basic block which form the inside of the loop.
88 header_bb_ = CreateLoopBB("loop_header", b);
89 body_bb_ = CreateLoopBB("loop_body", b);
90
91 // Function entry basic block.
92 // Emit alloca for the induction variable. We do this at the entry to the
93 // basic block to ensure the alloc only executes once per function (we could
94 // be emitting a nested loop).
95 llvm::Function* func = preheader_bb_->getParent();
96 b->SetInsertPoint(&func->getEntryBlock(),
97 func->getEntryBlock().getFirstInsertionPt());
98 llvm::Value* indvar_address = b->CreateAlloca(
99 start_index_->getType(), nullptr, GetQualifiedName("invar_address"));
100
101 // Preheader basic block.
102 // Initialize induction variable starting index. Create branch to the header.
103 b->SetInsertPoint(preheader_bb_);
104 b->CreateStore(start_index_, indvar_address);
105 // The preheader should not have a branch yet.
106 CHECK_EQ(preheader_bb_->getTerminator(), nullptr);
107 b->CreateBr(header_bb_);
108
109 // Header basic block.
110 // Emit the loop conditional branch. Load and compare indvar with ending
111 // index and jump to loop exit if equal. Jump to body otherwise.
112 b->SetInsertPoint(header_bb_);
113 indvar_ = b->CreateLoad(indvar_address, GetQualifiedName("indvar"));
114 llvm::Value* exit_cond = b->CreateICmpUGE(indvar_, end_index_);
115 b->CreateCondBr(/*Cond=*/exit_cond,
116 /*True=*/exit_bb_, /*False=*/body_bb_);
117
118 // Body basic block.
119 // Increment indvar, store indvar, and jump to header.
120 b->SetInsertPoint(body_bb_);
121 llvm::Value* step = step_;
122 llvm::Value* indvar = indvar_;
123
124 llvm::Value* indvar_inc = b->CreateAdd(indvar, step, "invar.inc",
125 /*HasNUW=*/true, /*HasNSW=*/true);
126 b->CreateStore(indvar_inc, indvar_address);
127 llvm::BranchInst* back_branch = b->CreateBr(header_bb_);
128
129 std::vector<llvm::Metadata*> loop_metadata = GetLoopMetadata(b);
130 if (!loop_metadata.empty()) {
131 llvm::LLVMContext* ctx = &start_index_->getContext();
132 auto temp_node = llvm::MDNode::getTemporary(*ctx, llvm::None);
133 loop_metadata.insert(loop_metadata.begin(), temp_node.get());
134 auto loop_id = llvm::MDNode::get(*ctx, loop_metadata);
135 loop_id->replaceOperandWith(0, loop_id);
136 back_branch->setMetadata(llvm::LLVMContext::MD_loop, loop_id);
137 }
138
139 // Re-point the IR builder to the loop exit block.
140 b->SetInsertPoint(exit_bb_);
141 }
142
GetLoopMetadata(llvm::IRBuilder<> * b)143 std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata(llvm::IRBuilder<>* b) {
144 const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable";
145 const char* const kLlvmLoopUnrollFullMDName = "llvm.loop.unroll.full";
146 const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable";
147 llvm::LLVMContext* ctx = &start_index_->getContext();
148
149 std::vector<llvm::Metadata*> result;
150 if (unroll_mode_ == xla::llvm_ir::UnrollMode::kNoUnroll) {
151 result.push_back(llvm::MDNode::get(
152 *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)}));
153 }
154
155 if (prevent_vectorization_) {
156 result.push_back(llvm::MDNode::get(
157 *ctx, {llvm::MDString::get(*ctx, kLlvmLoopVectorizeMDName),
158 llvm::ConstantAsMetadata::get(b->getFalse())}));
159 }
160
161 if (unroll_mode_ == xla::llvm_ir::UnrollMode::kFullyUnroll) {
162 result.push_back(llvm::MDNode::get(
163 *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollFullMDName)}));
164 }
165 return result;
166 }
167
GetQualifiedName(absl::string_view name)168 string ForLoop::GetQualifiedName(absl::string_view name) {
169 return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_));
170 }
171
CreateLoopBB(absl::string_view name,llvm::IRBuilder<> * b)172 llvm::BasicBlock* ForLoop::CreateLoopBB(absl::string_view name,
173 llvm::IRBuilder<>* b) {
174 return CreateBasicBlock(insert_before_bb_, GetQualifiedName(name), b);
175 }
176
AddLoop(absl::string_view suffix,llvm::Value * start_index,llvm::Value * end_index,UnrollMode unroll_mode,bool prevent_vectorization)177 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(absl::string_view suffix,
178 llvm::Value* start_index,
179 llvm::Value* end_index,
180 UnrollMode unroll_mode,
181 bool prevent_vectorization) {
182 return AddLoop(suffix, start_index, end_index, GetConstantWithIndexType(1),
183 unroll_mode, prevent_vectorization);
184 }
185
AddLoop(absl::string_view suffix,llvm::Value * start_index,llvm::Value * end_index,llvm::Value * stride,UnrollMode unroll_mode,bool prevent_vectorization)186 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(
187 absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index,
188 llvm::Value* stride, UnrollMode unroll_mode, bool prevent_vectorization) {
189 if (inner_loop_body_bb_ != nullptr) {
190 // Create this loop inside the previous one.
191 b_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt());
192 }
193 std::unique_ptr<ForLoop> loop(new ForLoop(
194 /*prefix=*/name_, suffix, start_index, end_index, stride, unroll_mode,
195 prevent_vectorization));
196 loop->Emit(b_);
197
198 if (outer_loop_preheader_bb_ == nullptr) {
199 outer_loop_preheader_bb_ = loop->GetPreheaderBasicBlock();
200 }
201
202 if (outer_loop_exit_bb_ == nullptr) {
203 outer_loop_exit_bb_ = loop->GetExitBasicBlock();
204 }
205
206 inner_loop_body_bb_ = loop->GetBodyBasicBlock();
207
208 return loop;
209 }
210
AddLoop(int64 start_index,int64 end_index,absl::string_view suffix,UnrollMode unroll_mode,bool prevent_vectorization)211 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
212 int64 end_index,
213 absl::string_view suffix,
214 UnrollMode unroll_mode,
215 bool prevent_vectorization) {
216 CHECK_LE(start_index, end_index);
217 return AddLoop(suffix, GetConstantWithIndexType(start_index),
218 GetConstantWithIndexType(end_index), unroll_mode,
219 prevent_vectorization);
220 }
221
AddLoop(int64 start_index,int64 end_index,int64 stride,absl::string_view suffix,UnrollMode unroll_mode,bool prevent_vectorization)222 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
223 int64 end_index, int64 stride,
224 absl::string_view suffix,
225 UnrollMode unroll_mode,
226 bool prevent_vectorization) {
227 CHECK_LE(start_index, end_index);
228 return AddLoop(suffix, GetConstantWithIndexType(start_index),
229 GetConstantWithIndexType(end_index),
230 GetConstantWithIndexType(stride), unroll_mode,
231 prevent_vectorization);
232 }
233
AddLoopsForShape(const Shape & shape,absl::string_view suffix)234 IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,
235 absl::string_view suffix) {
236 std::vector<int64> dimensions(shape.rank());
237 std::iota(dimensions.begin(), dimensions.end(), 0);
238 return IrArray::Index(AddLoopsForShapeOnDimensions(shape, dimensions, suffix),
239 shape, index_type_);
240 }
241
AddLoopsForShapeOnDimensions(const Shape & shape,absl::Span<const int64> dimensions,absl::string_view suffix)242 std::vector<llvm::Value*> ForLoopNest::AddLoopsForShapeOnDimensions(
243 const Shape& shape, absl::Span<const int64> dimensions,
244 absl::string_view suffix) {
245 std::vector<llvm::Value*> multi_index(shape.dimensions_size());
246 for (int64 dimension : dimensions) {
247 std::unique_ptr<llvm_ir::ForLoop> loop = AddLoop(
248 /*start_index=*/0,
249 /*end_index=*/shape.dimensions(dimension),
250 /*suffix=*/
251 llvm_ir::IrName(suffix, absl::StrCat(dimension)));
252 multi_index[dimension] = loop->GetIndVarValue();
253 }
254 return multi_index;
255 }
256
EmitOperandArrayLoopNest(const llvm_ir::IrArray & operand_array,int64 dimension_to_skip,absl::string_view name_suffix)257 std::vector<llvm::Value*> ForLoopNest::EmitOperandArrayLoopNest(
258 const llvm_ir::IrArray& operand_array, int64 dimension_to_skip,
259 absl::string_view name_suffix) {
260 // Prepares the dimension list we will use to emit the loop nest. Outermost
261 // loops are added first. Add loops in major-to-minor order, and skip the
262 // 'dimension_to_skip' dimension.
263 std::vector<int64> dimensions;
264 const Shape& shape = operand_array.GetShape();
265 // Initially get the dimensions in minor to major order, then reverse them.
266 for (int64 dimension : LayoutUtil::MinorToMajor(shape)) {
267 if (dimension != dimension_to_skip) {
268 dimensions.push_back(dimension);
269 }
270 }
271 absl::c_reverse(dimensions);
272
273 // Create loop nest with one for-loop for each dimension of the
274 // output.
275 std::vector<llvm::Value*> multi_index =
276 AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix);
277 // Verify every dimension except the 'dimension_to_skip' dimension was set in
278 // the index.
279 for (size_t dimension = 0; dimension < multi_index.size(); ++dimension) {
280 if (dimension == dimension_to_skip) {
281 DCHECK_EQ(nullptr, multi_index[dimension]);
282 } else {
283 DCHECK_NE(nullptr, multi_index[dimension]);
284 }
285 }
286 return multi_index;
287 }
288
289 } // namespace llvm_ir
290 } // namespace xla
291