• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
17 
18 #include <numeric>
19 #include <vector>
20 
21 #include "absl/algorithm/container.h"
22 #include "absl/strings/str_cat.h"
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/Function.h"
25 #include "llvm/IR/Instructions.h"
26 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/types.h"
29 #include "tensorflow/core/platform/logging.h"
30 
31 namespace xla {
32 namespace llvm_ir {
33 
ForLoop(absl::string_view prefix,absl::string_view suffix,llvm::Value * start_index,llvm::Value * end_index,llvm::Value * step,UnrollMode unroll_mode,bool prevent_vectorization)34 ForLoop::ForLoop(absl::string_view prefix, absl::string_view suffix,
35                  llvm::Value* start_index, llvm::Value* end_index,
36                  llvm::Value* step, UnrollMode unroll_mode,
37                  bool prevent_vectorization)
38     : prefix_(prefix),
39       suffix_(suffix),
40       start_index_(start_index),
41       end_index_(end_index),
42       step_(step),
43       insert_before_bb_(nullptr),
44       unroll_mode_(unroll_mode),
45       prevent_vectorization_(prevent_vectorization) {}
46 
EmitForLoop(absl::string_view prefix,llvm::Value * start_index,llvm::Value * end_index,llvm::Value * step,llvm::IRBuilder<> * b,UnrollMode unroll_mode,bool prevent_vectorization)47 /* static */ std::unique_ptr<ForLoop> ForLoop::EmitForLoop(
48     absl::string_view prefix, llvm::Value* start_index, llvm::Value* end_index,
49     llvm::Value* step, llvm::IRBuilder<>* b, UnrollMode unroll_mode,
50     bool prevent_vectorization) {
51   std::unique_ptr<ForLoop> loop(new ForLoop(prefix, /*suffix=*/"", start_index,
52                                             end_index, step, unroll_mode,
53                                             prevent_vectorization));
54   loop->Emit(b);
55   return loop;
56 }
57 
Emit(llvm::IRBuilder<> * b)58 void ForLoop::Emit(llvm::IRBuilder<>* b) {
59   // The preheader block is the block the builder is currently emitting
60   // code into.
61   preheader_bb_ = b->GetInsertBlock();
62 
63   llvm::BasicBlock::iterator insert_point = b->GetInsertPoint();
64   if (insert_point == preheader_bb_->end()) {
65     // We're emitting the loop at the end of a basic block. Verify there is no
66     // terminator (eg, branch) in the basic block.
67     CHECK_EQ(nullptr, preheader_bb_->getTerminator());
68 
69     exit_bb_ = CreateLoopBB("loop_exit", b);
70   } else {
71     // We're emitting the loop into the middle of a basic block. splitBasicBlock
72     // requires that this basic block be well-formed (have a terminator).
73     CHECK_NE(nullptr, preheader_bb_->getTerminator());
74 
75     // Split the preheader to create an exit basic block. The exit basic block
76     // will contain all instructions at or after insert_point.
77     exit_bb_ = preheader_bb_->splitBasicBlock(insert_point,
78                                               GetQualifiedName("loop_exit"));
79 
80     // splitBasicBlock adds an unconditional branch between the split basic
81     // blocks. Remove it. An unconditional branch will be added below from the
82     // preheader to the header.
83     preheader_bb_->getTerminator()->eraseFromParent();
84   }
85   insert_before_bb_ = exit_bb_;
86 
87   // Create remaining basic block which form the inside of the loop.
88   header_bb_ = CreateLoopBB("loop_header", b);
89   body_bb_ = CreateLoopBB("loop_body", b);
90 
91   // Function entry basic block.
92   // Emit alloca for the induction variable. We do this at the entry to the
93   // basic block to ensure the alloc only executes once per function (we could
94   // be emitting a nested loop).
95   llvm::Function* func = preheader_bb_->getParent();
96   b->SetInsertPoint(&func->getEntryBlock(),
97                     func->getEntryBlock().getFirstInsertionPt());
98   llvm::Value* indvar_address = b->CreateAlloca(
99       start_index_->getType(), nullptr, GetQualifiedName("invar_address"));
100 
101   // Preheader basic block.
102   // Initialize induction variable starting index. Create branch to the header.
103   b->SetInsertPoint(preheader_bb_);
104   b->CreateStore(start_index_, indvar_address);
105   // The preheader should not have a branch yet.
106   CHECK_EQ(preheader_bb_->getTerminator(), nullptr);
107   b->CreateBr(header_bb_);
108 
109   // Header basic block.
110   // Emit the loop conditional branch. Load and compare indvar with ending
111   // index and jump to loop exit if equal. Jump to body otherwise.
112   b->SetInsertPoint(header_bb_);
113   indvar_ = b->CreateLoad(start_index_->getType(), indvar_address,
114                           GetQualifiedName("indvar"));
115   llvm::Value* exit_cond = b->CreateICmpUGE(indvar_, end_index_);
116   b->CreateCondBr(/*Cond=*/exit_cond,
117                   /*True=*/exit_bb_, /*False=*/body_bb_);
118 
119   // Body basic block.
120   // Increment indvar, store indvar, and jump to header.
121   b->SetInsertPoint(body_bb_);
122   llvm::Value* step = step_;
123   llvm::Value* indvar = indvar_;
124 
125   llvm::Value* indvar_inc = b->CreateAdd(indvar, step, "invar.inc",
126                                          /*HasNUW=*/true, /*HasNSW=*/true);
127   b->CreateStore(indvar_inc, indvar_address);
128   llvm::BranchInst* back_branch = b->CreateBr(header_bb_);
129 
130   std::vector<llvm::Metadata*> loop_metadata = GetLoopMetadata(b);
131   if (!loop_metadata.empty()) {
132     llvm::LLVMContext* ctx = &start_index_->getContext();
133     auto temp_node = llvm::MDNode::getTemporary(*ctx, llvm::None);
134     loop_metadata.insert(loop_metadata.begin(), temp_node.get());
135     auto loop_id = llvm::MDNode::get(*ctx, loop_metadata);
136     loop_id->replaceOperandWith(0, loop_id);
137     back_branch->setMetadata(llvm::LLVMContext::MD_loop, loop_id);
138   }
139 
140   // Re-point the IR builder to the loop exit block.
141   b->SetInsertPoint(exit_bb_);
142 }
143 
GetLoopMetadata(llvm::IRBuilder<> * b)144 std::vector<llvm::Metadata*> ForLoop::GetLoopMetadata(llvm::IRBuilder<>* b) {
145   const char* const kLlvmLoopUnrollDisableMDName = "llvm.loop.unroll.disable";
146   const char* const kLlvmLoopUnrollFullMDName = "llvm.loop.unroll.full";
147   const char* const kLlvmLoopVectorizeMDName = "llvm.loop.vectorize.enable";
148   llvm::LLVMContext* ctx = &start_index_->getContext();
149 
150   std::vector<llvm::Metadata*> result;
151   if (unroll_mode_ == xla::llvm_ir::UnrollMode::kNoUnroll) {
152     result.push_back(llvm::MDNode::get(
153         *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollDisableMDName)}));
154   }
155 
156   if (prevent_vectorization_) {
157     result.push_back(llvm::MDNode::get(
158         *ctx, {llvm::MDString::get(*ctx, kLlvmLoopVectorizeMDName),
159                llvm::ConstantAsMetadata::get(b->getFalse())}));
160   }
161 
162   if (unroll_mode_ == xla::llvm_ir::UnrollMode::kFullyUnroll) {
163     result.push_back(llvm::MDNode::get(
164         *ctx, {llvm::MDString::get(*ctx, kLlvmLoopUnrollFullMDName)}));
165   }
166   return result;
167 }
168 
GetQualifiedName(absl::string_view name)169 std::string ForLoop::GetQualifiedName(absl::string_view name) {
170   return llvm_ir::IrName(prefix_, llvm_ir::IrName(name, suffix_));
171 }
172 
CreateLoopBB(absl::string_view name,llvm::IRBuilder<> * b)173 llvm::BasicBlock* ForLoop::CreateLoopBB(absl::string_view name,
174                                         llvm::IRBuilder<>* b) {
175   return CreateBasicBlock(insert_before_bb_, GetQualifiedName(name), b);
176 }
177 
AddLoop(absl::string_view suffix,llvm::Value * start_index,llvm::Value * end_index,UnrollMode unroll_mode,bool prevent_vectorization)178 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(absl::string_view suffix,
179                                               llvm::Value* start_index,
180                                               llvm::Value* end_index,
181                                               UnrollMode unroll_mode,
182                                               bool prevent_vectorization) {
183   return AddLoop(suffix, start_index, end_index, GetConstantWithIndexType(1),
184                  unroll_mode, prevent_vectorization);
185 }
186 
AddLoop(absl::string_view suffix,llvm::Value * start_index,llvm::Value * end_index,llvm::Value * stride,UnrollMode unroll_mode,bool prevent_vectorization)187 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(
188     absl::string_view suffix, llvm::Value* start_index, llvm::Value* end_index,
189     llvm::Value* stride, UnrollMode unroll_mode, bool prevent_vectorization) {
190   if (inner_loop_body_bb_ != nullptr) {
191     // Create this loop inside the previous one.
192     b_->SetInsertPoint(&*inner_loop_body_bb_->getFirstInsertionPt());
193   }
194   std::unique_ptr<ForLoop> loop(new ForLoop(
195       /*prefix=*/name_, suffix, start_index, end_index, stride, unroll_mode,
196       prevent_vectorization));
197   loop->Emit(b_);
198 
199   if (outer_loop_preheader_bb_ == nullptr) {
200     outer_loop_preheader_bb_ = loop->GetPreheaderBasicBlock();
201   }
202 
203   if (outer_loop_exit_bb_ == nullptr) {
204     outer_loop_exit_bb_ = loop->GetExitBasicBlock();
205   }
206 
207   inner_loop_body_bb_ = loop->GetBodyBasicBlock();
208 
209   return loop;
210 }
211 
AddLoop(int64_t start_index,int64_t end_index,absl::string_view suffix,UnrollMode unroll_mode,bool prevent_vectorization)212 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64_t start_index,
213                                               int64_t end_index,
214                                               absl::string_view suffix,
215                                               UnrollMode unroll_mode,
216                                               bool prevent_vectorization) {
217   CHECK_LE(start_index, end_index);
218   return AddLoop(suffix, GetConstantWithIndexType(start_index),
219                  GetConstantWithIndexType(end_index), unroll_mode,
220                  prevent_vectorization);
221 }
222 
AddLoop(int64_t start_index,int64_t end_index,int64_t stride,absl::string_view suffix,UnrollMode unroll_mode,bool prevent_vectorization)223 std::unique_ptr<ForLoop> ForLoopNest::AddLoop(int64_t start_index,
224                                               int64_t end_index, int64_t stride,
225                                               absl::string_view suffix,
226                                               UnrollMode unroll_mode,
227                                               bool prevent_vectorization) {
228   CHECK_LE(start_index, end_index);
229   return AddLoop(suffix, GetConstantWithIndexType(start_index),
230                  GetConstantWithIndexType(end_index),
231                  GetConstantWithIndexType(stride), unroll_mode,
232                  prevent_vectorization);
233 }
234 
AddLoopsForShape(const Shape & shape,absl::string_view suffix)235 IrArray::Index ForLoopNest::AddLoopsForShape(const Shape& shape,
236                                              absl::string_view suffix) {
237   std::vector<int64_t> dimensions(shape.rank());
238   std::iota(dimensions.begin(), dimensions.end(), 0);
239   return IrArray::Index(AddLoopsForShapeOnDimensions(shape, dimensions, suffix),
240                         shape, index_type_);
241 }
242 
AddLoopsForShapeOnDimensions(const Shape & shape,absl::Span<const int64_t> dimensions,absl::string_view suffix)243 std::vector<llvm::Value*> ForLoopNest::AddLoopsForShapeOnDimensions(
244     const Shape& shape, absl::Span<const int64_t> dimensions,
245     absl::string_view suffix) {
246   std::vector<llvm::Value*> multi_index(shape.dimensions_size());
247   for (int64_t dimension : dimensions) {
248     std::unique_ptr<llvm_ir::ForLoop> loop = AddLoop(
249         /*start_index=*/0,
250         /*end_index=*/shape.dimensions(dimension),
251         /*suffix=*/
252         llvm_ir::IrName(suffix, absl::StrCat(dimension)));
253     multi_index[dimension] = loop->GetIndVarValue();
254   }
255   return multi_index;
256 }
257 
EmitOperandArrayLoopNest(const llvm_ir::IrArray & operand_array,int64_t dimension_to_skip,absl::string_view name_suffix)258 std::vector<llvm::Value*> ForLoopNest::EmitOperandArrayLoopNest(
259     const llvm_ir::IrArray& operand_array, int64_t dimension_to_skip,
260     absl::string_view name_suffix) {
261   // Prepares the dimension list we will use to emit the loop nest. Outermost
262   // loops are added first. Add loops in major-to-minor order, and skip the
263   // 'dimension_to_skip' dimension.
264   std::vector<int64_t> dimensions;
265   const Shape& shape = operand_array.GetShape();
266   // Initially get the dimensions in minor to major order, then reverse them.
267   for (int64_t dimension : LayoutUtil::MinorToMajor(shape)) {
268     if (dimension != dimension_to_skip) {
269       dimensions.push_back(dimension);
270     }
271   }
272   absl::c_reverse(dimensions);
273 
274   // Create loop nest with one for-loop for each dimension of the
275   // output.
276   std::vector<llvm::Value*> multi_index =
277       AddLoopsForShapeOnDimensions(shape, dimensions, name_suffix);
278   // Verify every dimension except the 'dimension_to_skip' dimension was set in
279   // the index.
280   for (size_t dimension = 0; dimension < multi_index.size(); ++dimension) {
281     if (dimension == dimension_to_skip) {
282       DCHECK_EQ(nullptr, multi_index[dimension]);
283     } else {
284       DCHECK_NE(nullptr, multi_index[dimension]);
285     }
286   }
287   return multi_index;
288 }
289 
290 }  // namespace llvm_ir
291 }  // namespace xla
292