• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
17 
18 #include <numeric>
19 #include <vector>
20 
21 #include "absl/algorithm/container.h"
22 #include "absl/strings/str_cat.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/Function.h"
25 #include "llvm/IR/Instructions.h"
26 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/core/platform/logging.h"
30 
31 namespace xla {
32 namespace llvm_ir {
33 
// Constructs a ForLoop describing the iteration space
// [start_index, end_index) with the given step. No IR is generated here;
// the loop's basic blocks are created when Emit() is called.
//
// `prefix`/`suffix` qualify the names of the loop's basic blocks and
// induction variable (see GetQualifiedName). `unroll_mode` and
// `prevent_vectorization` control the llvm.loop metadata attached to the
// back branch (see GetLoopMetadata).
ForLoop::ForLoop(absl::string_view prefix, absl::string_view suffix,
                 llvm::Value* start_index, llvm::Value* end_index,
                 llvm::Value* step, UnrollMode unroll_mode,
                 bool prevent_vectorization)
    : prefix_(prefix),
      suffix_(suffix),
      start_index_(start_index),
      end_index_(end_index),
      step_(step),
      insert_before_bb_(nullptr),
      unroll_mode_(unroll_mode),
      prevent_vectorization_(prevent_vectorization) {}
46 
EmitForLoop(absl::string_view prefix,llvm::Value * start_index,llvm::Value * end_index,llvm::Value * step,llvm::IRBuilder<> * b,UnrollMode unroll_mode,bool prevent_vectorization)47 /* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop(
48     absl::string_view prefix, llvm::Value* start_index, llvm::Value* end_index,
49     llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode,
50     bool prevent_vectorization) {
51   std::unique_ptr<ForLoop> loop(new ForLoop(prefix, /*suffix=*/"", start_index,
52                                             end_index, step, unroll_mode,
53                                             prevent_vectorization));
54   loop->Emit(b);
55   return loop;
56 }
57 
// Emits the loop's CFG (preheader -> header -> body -> exit) at the builder's
// current insertion point. On return, the builder points at the start of the
// exit block; callers populate the loop by inserting into the body block.
// The statement order below is significant: blocks must exist before branches
// to them are emitted, and the builder's insert point is repositioned several
// times.
void ForLoop::Emit(llvm::IRBuilder<>* b) {
  // The preheader block is the block the builder is currently emitting
  // code into.
  preheader_bb_ = b->GetInsertBlock();

  llvm::BasicBlock::iterator insert_point = b->GetInsertPoint();
  if (insert_point == preheader_bb_->end()) {
    // We're emitting the loop at the end of a basic block. Verify there is no
    // terminator (eg, branch) in the basic block.
    CHECK_EQ(nullptr, preheader_bb_->getTerminator());

    exit_bb_ = CreateLoopBB("loop_exit", b);
  } else {
    // We're emitting the loop into the middle of a basic block. splitBasicBlock
    // requires that this basic block be well-formed (have a terminator).
    CHECK_NE(nullptr, preheader_bb_->getTerminator());

    // Split the preheader to create an exit basic block. The exit basic block
    // will contain all instructions at or after insert_point.
    exit_bb_ = preheader_bb_->splitBasicBlock(insert_point,
                                              GetQualifiedName("loop_exit"));

    // splitBasicBlock adds an unconditional branch between the split basic
    // blocks. Remove it. An unconditional branch will be added below from the
    // preheader to the header.
    preheader_bb_->getTerminator()->eraseFromParent();
  }
  // Subsequent blocks (header/body) are inserted before the exit block so the
  // loop's blocks stay contiguous in the function.
  insert_before_bb_ = exit_bb_;

  // Create remaining basic block which form the inside of the loop.
  header_bb_ = CreateLoopBB("loop_header", b);
  body_bb_ = CreateLoopBB("loop_body", b);

  // Function entry basic block.
  // Emit alloca for the induction variable. We do this at the entry to the
  // basic block to ensure the alloc only executes once per function (we could
  // be emitting a nested loop).
  llvm::Function* func = preheader_bb_->getParent();
  b->SetInsertPoint(&func->getEntryBlock(),
                    func->getEntryBlock().getFirstInsertionPt());
  llvm::Value* indvar_address = b->CreateAlloca(
      start_index_->getType(), nullptr, GetQualifiedName("invar_address"));

  // Preheader basic block.
  // Initialize induction variable starting index. Create branch to the header.
  b->SetInsertPoint(preheader_bb_);
  b->CreateStore(start_index_, indvar_address);
  // The preheader should not have a branch yet.
  CHECK_EQ(preheader_bb_->getTerminator(), nullptr);
  b->CreateBr(header_bb_);

  // Header basic block.
  // Emit the loop conditional branch. Load and compare indvar with ending
  // index and jump to loop exit if equal. Jump to body otherwise.
  // NOTE: the comparison is unsigned (UGE) — assumes indices are non-negative
  // when interpreted in the index type.
  b->SetInsertPoint(header_bb_);
  indvar_ = b->CreateLoad(indvar_address, GetQualifiedName("indvar"));
  llvm::Value* exit_cond = b->CreateICmpUGE(indvar_, end_index_);
  b->CreateCondBr(/*Cond=*/exit_cond,
                  /*True=*/exit_bb_, /*False=*/body_bb_);

  // Body basic block.
  // Increment indvar, store indvar, and jump to header.
  b->SetInsertPoint(body_bb_);
  llvm::Value* step = step_;
  llvm::Value* indvar = indvar_;

  // nuw/nsw flags assert the induction variable never wraps.
  llvm::Value* indvar_inc = b->CreateAdd(indvar, step, "invar.inc",
                                         /*HasNUW=*/true, /*HasNSW=*/true);
  b->CreateStore(indvar_inc, indvar_address);
  llvm::BranchInst* back_branch = b->CreateBr(header_bb_);

  // Attach llvm.loop metadata (unroll/vectorize hints) to the back branch.
  // Per LLVM convention, the loop-id MDNode's first operand is the node
  // itself; a temporary node is used to create that self-reference.
  std::vector<llvm::Metadata*> loop_metadata = GetLoopMetadata(b);
  if (!loop_metadata.empty()) {
    llvm::LLVMContext* ctx = &start_index_->getContext();
    auto temp_node = llvm::MDNode::getTemporary(*ctx, llvm::None);
    loop_metadata.insert(loop_metadata.begin(), temp_node.get());
    auto loop_id = llvm::MDNode::get(*ctx, loop_metadata);
    loop_id->replaceOperandWith(0, loop_id);
    back_branch->setMetadata(llvm::LLVMContext::MD_loop, loop_id);
  }

  // Re-point the IR builder to the loop exit block.
  b->SetInsertPoint(exit_bb_);
}
142 
GetLoopMetadata(llvm::IRBuilder<> * b)143 std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata(llvm::IRBuilder<>* b) {
144   const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable";
145   const char* const kLlvmLoopUnrollFullMDName = "llvm.loop.unroll.full";
146   const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable";
147   llvm::LLVMContext* ctx = &start_index_->getContext();
148 
149   std::vector<llvm::Metadata*> result;
150   if (unroll_mode_ == xla::llvm_ir::UnrollMode::kNoUnroll) {
151     result.push_back(llvm::MDNode::get(
152         *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)}));
153   }
154 
155   if (prevent_vectorization_) {
156     result.push_back(llvm::MDNode::get(
157         *ctx, {llvm::MDString::get(*ctx, kLlvmLoopVectorizeMDName),
158                llvm::ConstantAsMetadata::get(b->getFalse())}));
159   }
160 
161   if (unroll_mode_ == xla::llvm_ir::UnrollMode::kFullyUnroll) {
162     result.push_back(llvm::MDNode::get(
163         *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollFullMDName)}));
164   }
165   return result;
166 }
167 
GetQualifiedName(absl::string_view name)168 string ForLoop::GetQualifiedName(absl::string_view name) {
169   return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_));
170 }
171 
CreateLoopBB(absl::string_view name,llvm::IRBuilder<> * b)172 llvm::BasicBlock* ForLoop::CreateLoopBB(absl::string_view name,
173                                         llvm::IRBuilder<>* b) {
174   return CreateBasicBlock(insert_before_bb_, GetQualifiedName(name), b);
175 }
176 
AddLoop(absl::string_view suffix,llvm::Value * start_index,llvm::Value * end_index,UnrollMode unroll_mode,bool prevent_vectorization)177 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(absl::string_view suffix,
178                                               llvm::Value* start_index,
179                                               llvm::Value* end_index,
180                                               UnrollMode unroll_mode,
181                                               bool prevent_vectorization) {
182   return AddLoop(suffix, start_index, end_index, GetConstantWithIndexType(1),
183                  unroll_mode, prevent_vectorization);
184 }
185 
AddLoop(absl::string_view suffix,llvm::Value * start_index,llvm::Value * end_index,llvm::Value * stride,UnrollMode unroll_mode,bool prevent_vectorization)186 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(
187     absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index,
188     llvm::Value* stride, UnrollMode unroll_mode, bool prevent_vectorization) {
189   if (inner_loop_body_bb_ != nullptr) {
190     // Create this loop inside the previous one.
191     b_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt());
192   }
193   std::unique_ptr<ForLoop> loop(new ForLoop(
194       /*prefix=*/name_, suffix, start_index, end_index, stride, unroll_mode,
195       prevent_vectorization));
196   loop->Emit(b_);
197 
198   if (outer_loop_preheader_bb_ == nullptr) {
199     outer_loop_preheader_bb_ = loop->GetPreheaderBasicBlock();
200   }
201 
202   if (outer_loop_exit_bb_ == nullptr) {
203     outer_loop_exit_bb_ = loop->GetExitBasicBlock();
204   }
205 
206   inner_loop_body_bb_ = loop->GetBodyBasicBlock();
207 
208   return loop;
209 }
210 
AddLoop(int64 start_index,int64 end_index,absl::string_view suffix,UnrollMode unroll_mode,bool prevent_vectorization)211 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
212                                               int64 end_index,
213                                               absl::string_view suffix,
214                                               UnrollMode unroll_mode,
215                                               bool prevent_vectorization) {
216   CHECK_LE(start_index, end_index);
217   return AddLoop(suffix, GetConstantWithIndexType(start_index),
218                  GetConstantWithIndexType(end_index), unroll_mode,
219                  prevent_vectorization);
220 }
221 
AddLoop(int64 start_index,int64 end_index,int64 stride,absl::string_view suffix,UnrollMode unroll_mode,bool prevent_vectorization)222 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64 start_index,
223                                               int64 end_index, int64 stride,
224                                               absl::string_view suffix,
225                                               UnrollMode unroll_mode,
226                                               bool prevent_vectorization) {
227   CHECK_LE(start_index, end_index);
228   return AddLoop(suffix, GetConstantWithIndexType(start_index),
229                  GetConstantWithIndexType(end_index),
230                  GetConstantWithIndexType(stride), unroll_mode,
231                  prevent_vectorization);
232 }
233 
AddLoopsForShape(const Shape & shape,absl::string_view suffix)234 IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,
235                                              absl::string_view suffix) {
236   std::vector<int64> dimensions(shape.rank());
237   std::iota(dimensions.begin(), dimensions.end(), 0);
238   return IrArray::Index(AddLoopsForShapeOnDimensions(shape, dimensions, suffix),
239                         shape, index_type_);
240 }
241 
AddLoopsForShapeOnDimensions(const Shape & shape,absl::Span<const int64> dimensions,absl::string_view suffix)242 std::vector<llvm::Value*> ForLoopNest::AddLoopsForShapeOnDimensions(
243     const Shape& shape, absl::Span<const int64> dimensions,
244     absl::string_view suffix) {
245   std::vector<llvm::Value*> multi_index(shape.dimensions_size());
246   for (int64 dimension : dimensions) {
247     std::unique_ptr<llvm_ir::ForLoop> loop = AddLoop(
248         /*start_index=*/0,
249         /*end_index=*/shape.dimensions(dimension),
250         /*suffix=*/
251         llvm_ir::IrName(suffix, absl::StrCat(dimension)));
252     multi_index[dimension] = loop->GetIndVarValue();
253   }
254   return multi_index;
255 }
256 
EmitOperandArrayLoopNest(const llvm_ir::IrArray & operand_array,int64 dimension_to_skip,absl::string_view name_suffix)257 std::vector<llvm::Value*> ForLoopNest::EmitOperandArrayLoopNest(
258     const llvm_ir::IrArray& operand_array, int64 dimension_to_skip,
259     absl::string_view name_suffix) {
260   // Prepares the dimension list we will use to emit the loop nest. Outermost
261   // loops are added first. Add loops in major-to-minor order, and skip the
262   // 'dimension_to_skip' dimension.
263   std::vector<int64> dimensions;
264   const Shape& shape = operand_array.GetShape();
265   // Initially get the dimensions in minor to major order, then reverse them.
266   for (int64 dimension : LayoutUtil::MinorToMajor(shape)) {
267     if (dimension != dimension_to_skip) {
268       dimensions.push_back(dimension);
269     }
270   }
271   absl::c_reverse(dimensions);
272 
273   // Create loop nest with one for-loop for each dimension of the
274   // output.
275   std::vector<llvm::Value*> multi_index =
276       AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix);
277   // Verify every dimension except the 'dimension_to_skip' dimension was set in
278   // the index.
279   for (size_t dimension = 0; dimension < multi_index.size(); ++dimension) {
280     if (dimension == dimension_to_skip) {
281       DCHECK_EQ(nullptr, multi_index[dimension]);
282     } else {
283       DCHECK_NE(nullptr, multi_index[dimension]);
284     }
285   }
286   return multi_index;
287 }
288 
289 }  // namespace llvm_ir
290 }  // namespace xla
291