• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
17 
18 #include <numeric>
19 #include <vector>
20 
21 #include "absl/strings/str_cat.h"
22 #include "llvm/IR/Constants.h"
23 #include "llvm/IR/Function.h"
24 #include "llvm/IR/Instructions.h"
25 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
26 #include "tensorflow/compiler/xla/shape_util.h"
27 #include "tensorflow/compiler/xla/types.h"
28 #include "tensorflow/core/platform/logging.h"
29 
30 namespace xla {
31 namespace llvm_ir {
32 
ForLoop(absl::string_view prefix,absl::string_view suffix,llvm::Value * start_index,llvm::Value * end_index,llvm::Value * step,UnrollMode unroll_mode,bool prevent_vectorization)33 ForLoop::ForLoop(absl::string_view prefix, absl::string_view suffix,
34                  llvm::Value* start_index, llvm::Value* end_index,
35                  llvm::Value* step, UnrollMode unroll_mode,
36                  bool prevent_vectorization)
37     : prefix_(prefix),
38       suffix_(suffix),
39       start_index_(start_index),
40       end_index_(end_index),
41       step_(step),
42       insert_before_bb_(nullptr),
43       unroll_mode_(unroll_mode),
44       prevent_vectorization_(prevent_vectorization) {}
45 
EmitForLoop(absl::string_view prefix,llvm::Value * start_index,llvm::Value * end_index,llvm::Value * step,llvm::IRBuilder<> * b,UnrollMode unroll_mode,bool prevent_vectorization)46 /* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop(
47     absl::string_view prefix, llvm::Value* start_index, llvm::Value* end_index,
48     llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode,
49     bool prevent_vectorization) {
50   std::unique_ptr<ForLoop> loop(new ForLoop(prefix, /*suffix=*/"", start_index,
51                                             end_index, step, unroll_mode,
52                                             prevent_vectorization));
53   loop->Emit(b);
54   return loop;
55 }
56 
Emit(llvm::IRBuilder<> * b)57 void ForLoop::Emit(llvm::IRBuilder<>* b) {
58   // The preheader block is the block the builder is currently emitting
59   // code into.
60   preheader_bb_ = b->GetInsertBlock();
61 
62   llvm::BasicBlock::iterator insert_point = b->GetInsertPoint();
63   if (insert_point == preheader_bb_->end()) {
64     // We're emitting the loop at the end of a basic block. Verify there is no
65     // terminator (eg, branch) in the basic block.
66     CHECK_EQ(nullptr, preheader_bb_->getTerminator());
67 
68     exit_bb_ = CreateLoopBB("loop_exit", b);
69   } else {
70     // We're emitting the loop into the middle of a basic block. splitBasicBlock
71     // requires that this basic block be well-formed (have a terminator).
72     CHECK_NE(nullptr, preheader_bb_->getTerminator());
73 
74     // Split the preheader to create an exit basic block. The exit basic block
75     // will contain all instructions at or after insert_point.
76     exit_bb_ = preheader_bb_->splitBasicBlock(insert_point,
77                                               GetQualifiedName("loop_exit"));
78 
79     // splitBasicBlock adds an unconditional branch between the split basic
80     // blocks. Remove it. An unconditional branch will be added below from the
81     // preheader to the header.
82     preheader_bb_->getTerminator()->eraseFromParent();
83   }
84   insert_before_bb_ = exit_bb_;
85 
86   // Create remaining basic block which form the inside of the loop.
87   header_bb_ = CreateLoopBB("loop_header", b);
88   body_bb_ = CreateLoopBB("loop_body", b);
89 
90   // Function entry basic block.
91   // Emit alloca for the induction variable. We do this at the entry to the
92   // basic block to ensure the alloc only executes once per function (we could
93   // be emitting a nested loop).
94   llvm::Function* func = preheader_bb_->getParent();
95   b->SetInsertPoint(&func->getEntryBlock(),
96                     func->getEntryBlock().getFirstInsertionPt());
97   llvm::Value* indvar_address = b->CreateAlloca(
98       start_index_->getType(), nullptr, GetQualifiedName("invar_address"));
99 
100   // Preheader basic block.
101   // Initialize induction variable starting index. Create branch to the header.
102   b->SetInsertPoint(preheader_bb_);
103   b->CreateStore(start_index_, indvar_address);
104   // The preheader should not have a branch yet.
105   CHECK_EQ(preheader_bb_->getTerminator(), nullptr);
106   b->CreateBr(header_bb_);
107 
108   // Header basic block.
109   // Emit the loop conditional branch. Load and compare indvar with ending
110   // index and jump to loop exit if equal. Jump to body otherwise.
111   b->SetInsertPoint(header_bb_);
112   indvar_ = b->CreateLoad(indvar_address, GetQualifiedName("indvar"));
113   llvm::Value* exit_cond = b->CreateICmpUGE(indvar_, end_index_);
114   b->CreateCondBr(/*Cond=*/exit_cond,
115                   /*True=*/exit_bb_, /*False=*/body_bb_);
116 
117   // Body basic block.
118   // Increment indvar, store indvar, and jump to header.
119   b->SetInsertPoint(body_bb_);
120   llvm::Value* step = step_;
121   llvm::Value* indvar = indvar_;
122 
123   llvm::Value* indvar_inc = b->CreateAdd(indvar, step, "invar.inc",
124                                          /*HasNUW=*/true, /*HasNSW=*/true);
125   b->CreateStore(indvar_inc, indvar_address);
126   llvm::BranchInst* back_branch = b->CreateBr(header_bb_);
127 
128   std::vector<llvm::Metadata*> loop_metadata = GetLoopMetadata(b);
129   if (!loop_metadata.empty()) {
130     llvm::LLVMContext* ctx = &start_index_->getContext();
131     auto temp_node = llvm::MDNode::getTemporary(*ctx, llvm::None);
132     loop_metadata.insert(loop_metadata.begin(), temp_node.get());
133     auto loop_id = llvm::MDNode::get(*ctx, loop_metadata);
134     loop_id->replaceOperandWith(0, loop_id);
135     back_branch->setMetadata(llvm::LLVMContext::MD_loop, loop_id);
136   }
137 
138   // Re-point the IR builder to the loop exit block.
139   b->SetInsertPoint(exit_bb_);
140 }
141 
GetLoopMetadata(llvm::IRBuilder<> * b)142 std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata(llvm::IRBuilder<>* b) {
143   const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable";
144   const char* const kLlvmLoopUnrollFullMDName = "llvm.loop.unroll.full";
145   const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable";
146   llvm::LLVMContext* ctx = &start_index_->getContext();
147 
148   std::vector<llvm::Metadata*> result;
149   if (unroll_mode_ == xla::llvm_ir::UnrollMode::kNoUnroll) {
150     result.push_back(llvm::MDNode::get(
151         *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)}));
152   }
153 
154   if (prevent_vectorization_) {
155     result.push_back(llvm::MDNode::get(
156         *ctx, {llvm::MDString::get(*ctx, kLlvmLoopVectorizeMDName),
157                llvm::ConstantAsMetadata::get(b->getFalse())}));
158   }
159 
160   if (unroll_mode_ == xla::llvm_ir::UnrollMode::kFullyUnroll) {
161     result.push_back(llvm::MDNode::get(
162         *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollFullMDName)}));
163   }
164   return result;
165 }
166 
GetQualifiedName(absl::string_view name)167 string ForLoop::GetQualifiedName(absl::string_view name) {
168   return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_));
169 }
170 
CreateLoopBB(absl::string_view name,llvm::IRBuilder<> * b)171 llvm::BasicBlock* ForLoop::CreateLoopBB(absl::string_view name,
172                                         llvm::IRBuilder<>* b) {
173   return CreateBasicBlock(insert_before_bb_, GetQualifiedName(name), b);
174 }
175 
AddLoop(absl::string_view suffix,llvm::Value * start_index,llvm::Value * end_index,UnrollMode unroll_mode,bool prevent_vectorization)176 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(absl::string_view suffix,
177                                               llvm::Value* start_index,
178                                               llvm::Value* end_index,
179                                               UnrollMode unroll_mode,
180                                               bool prevent_vectorization) {
181   return AddLoop(suffix, start_index, end_index, GetConstantWithIndexType(1),
182                  unroll_mode, prevent_vectorization);
183 }
184 
AddLoop(absl::string_view suffix,llvm::Value * start_index,llvm::Value * end_index,llvm::Value * stride,UnrollMode unroll_mode,bool prevent_vectorization)185 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(
186     absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index,
187     llvm::Value* stride, UnrollMode unroll_mode, bool prevent_vectorization) {
188   if (inner_loop_body_bb_ != nullptr) {
189     // Create this loop inside the previous one.
190     b_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt());
191   }
192   std::unique_ptr<ForLoop> loop(new ForLoop(
193       /*prefix=*/name_, suffix, start_index, end_index, stride, unroll_mode,
194       prevent_vectorization));
195   loop->Emit(b_);
196 
197   if (outer_loop_preheader_bb_ == nullptr) {
198     outer_loop_preheader_bb_ = loop->GetPreheaderBasicBlock();
199   }
200 
201   if (outer_loop_exit_bb_ == nullptr) {
202     outer_loop_exit_bb_ = loop->GetExitBasicBlock();
203   }
204 
205   inner_loop_body_bb_ = loop->GetBodyBasicBlock();
206 
207   return loop;
208 }
209 
AddLoop(int64 start_index,int64 end_index,absl::string_view suffix,UnrollMode unroll_mode,bool prevent_vectorization)210 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
211                                               int64 end_index,
212                                               absl::string_view suffix,
213                                               UnrollMode unroll_mode,
214                                               bool prevent_vectorization) {
215   CHECK_LE(start_index, end_index);
216   return AddLoop(suffix, GetConstantWithIndexType(start_index),
217                  GetConstantWithIndexType(end_index), unroll_mode,
218                  prevent_vectorization);
219 }
220 
AddLoop(int64 start_index,int64 end_index,int64 stride,absl::string_view suffix,UnrollMode unroll_mode,bool prevent_vectorization)221 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
222                                               int64 end_index, int64 stride,
223                                               absl::string_view suffix,
224                                               UnrollMode unroll_mode,
225                                               bool prevent_vectorization) {
226   CHECK_LE(start_index, end_index);
227   return AddLoop(suffix, GetConstantWithIndexType(start_index),
228                  GetConstantWithIndexType(end_index),
229                  GetConstantWithIndexType(stride), unroll_mode,
230                  prevent_vectorization);
231 }
232 
AddLoopsForShape(const Shape & shape,absl::string_view suffix)233 IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,
234                                              absl::string_view suffix) {
235   std::vector<int64> dimensions(shape.rank());
236   std::iota(dimensions.begin(), dimensions.end(), 0);
237   return IrArray::Index(AddLoopsForShapeOnDimensions(shape, dimensions, suffix),
238                         shape, index_type_);
239 }
240 
AddLoopsForShapeOnDimensions(const Shape & shape,absl::Span<const int64> dimensions,absl::string_view suffix)241 std::vector<llvm::Value*> ForLoopNest::AddLoopsForShapeOnDimensions(
242     const Shape& shape, absl::Span<const int64> dimensions,
243     absl::string_view suffix) {
244   std::vector<llvm::Value*> multi_index(shape.dimensions_size());
245   for (int64 dimension : dimensions) {
246     std::unique_ptr<llvm_ir::ForLoop> loop = AddLoop(
247         /*start_index=*/0,
248         /*end_index=*/shape.dimensions(dimension),
249         /*suffix=*/
250         llvm_ir::IrName(suffix, absl::StrCat(dimension)));
251     multi_index[dimension] = loop->GetIndVarValue();
252   }
253   return multi_index;
254 }
255 
EmitOperandArrayLoopNest(const llvm_ir::IrArray & operand_array,int64 dimension_to_skip,absl::string_view name_suffix)256 std::vector<llvm::Value*> ForLoopNest::EmitOperandArrayLoopNest(
257     const llvm_ir::IrArray& operand_array, int64 dimension_to_skip,
258     absl::string_view name_suffix) {
259   // Prepares the dimension list we will use to emit the loop nest. Outermost
260   // loops are added first. Add loops in major-to-minor order, and skip the
261   // 'dimension_to_skip' dimension.
262   std::vector<int64> dimensions;
263   const Shape& shape = operand_array.GetShape();
264   for (int64 dimension : LayoutUtil::MinorToMajor(shape)) {
265     if (dimension != dimension_to_skip) {
266       dimensions.push_back(dimension);
267     }
268   }
269 
270   // Create loop nest with one for-loop for each dimension of the
271   // output.
272   std::vector<llvm::Value*> multi_index =
273       AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix);
274   // Verify every dimension except the 'dimension_to_skip' dimension was set in
275   // the index.
276   for (size_t dimension = 0; dimension < multi_index.size(); ++dimension) {
277     if (dimension == dimension_to_skip) {
278       DCHECK_EQ(nullptr, multi_index[dimension]);
279     } else {
280       DCHECK_NE(nullptr, multi_index[dimension]);
281     }
282   }
283   return multi_index;
284 }
285 
286 }  // namespace llvm_ir
287 }  // namespace xla
288