• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "absl/strings/str_format.h"
22 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/compiler/xla/status_macros.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/protobuf.h"
29 #include "tensorflow/core/platform/types.h"
30 
31 namespace xla {
32 namespace llvm_ir {
33 
LoopEmitter(const BodyEmitter & body_emitter,const Shape & shape,llvm::IRBuilder<> * b)34 LoopEmitter::LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape,
35                          llvm::IRBuilder<>* b)
36     : body_emitter_(body_emitter), shape_(shape), b_(b) {}
37 
LoopEmitter(const BodyEmitter & body_emitter,const Shape & shape,std::vector<llvm::Value * > dynamic_dims,llvm::IRBuilder<> * b)38 LoopEmitter::LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape,
39                          std::vector<llvm::Value*> dynamic_dims,
40                          llvm::IRBuilder<>* b)
41     : LoopEmitter::LoopEmitter(body_emitter, shape, b) {
42   CHECK_EQ(dynamic_dims.size(), shape_.dimensions_size());
43   dynamic_dims_ = std::move(dynamic_dims);
44 }
45 
LoopEmitter(const ElementGenerator & target_element_generator,const IrArray & target_array,llvm::IRBuilder<> * b)46 LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
47                          const IrArray& target_array, llvm::IRBuilder<>* b)
48     : body_emitter_([=](const llvm_ir::IrArray::Index array_index) -> Status {
49         // Convert target_element_generator to a BodyEmitter.
50         TF_ASSIGN_OR_RETURN(llvm::Value * target_element,
51                             target_element_generator(array_index));
52         target_array.EmitWriteArrayElement(array_index, target_element, b);
53         return Status::OK();
54       }),
55       shape_(target_array.GetShape()),
56       b_(b) {}
57 
MakeBodyEmitterForMultiOutput(const ElementGenerator & target_element_generator,const std::vector<IrArray> & target_arrays,llvm::IRBuilder<> * b)58 static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutput(
59     const ElementGenerator& target_element_generator,
60     const std::vector<IrArray>& target_arrays, llvm::IRBuilder<>* b) {
61   return [=](const llvm_ir::IrArray::Index array_index) {
62     TF_ASSIGN_OR_RETURN(llvm::Value * target_element,
63                         target_element_generator(array_index));
64     CHECK(target_element->getType()->isStructTy())
65         << "This BodyEmitter is for multi-output, but target element "
66            "generator does not produce values of struct type.";
67     CHECK_EQ(target_element->getType()->getStructNumElements(),
68              target_arrays.size());
69 
70     for (int64 i = 0; i < target_arrays.size(); ++i) {
71       target_arrays[i].EmitWriteArrayElement(
72           array_index, b->CreateExtractValue(target_element, i), b);
73     }
74     return Status::OK();
75   };
76 }
77 
LoopEmitter(const ElementGenerator & target_element_generator,absl::Span<const IrArray> target_arrays,llvm::IRBuilder<> * b)78 LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
79                          absl::Span<const IrArray> target_arrays,
80                          llvm::IRBuilder<>* b)
81     : body_emitter_(MakeBodyEmitterForMultiOutput(
82           target_element_generator,
83           std::vector<IrArray>(target_arrays.begin(), target_arrays.end()), b)),
84       shape_(target_arrays[0].GetShape()),
85       b_(b) {
86   // Sanity check: In multi-output fusion, all shapes produced must have the
87   // same dimensions.
88   for (const IrArray& array : target_arrays) {
89     CHECK(ShapeUtil::SameDimensions(shape_, array.GetShape()))
90         << ": '" << shape_.ShortDebugString() << "' does not match '"
91         << array.GetShape().ShortDebugString() << "'";
92   }
93 }
94 
EmitStaticIndex(ForLoopNest * loop_nest,llvm::Type * index_type)95 IrArray::Index LoopEmitter::EmitStaticIndex(ForLoopNest* loop_nest,
96                                             llvm::Type* index_type) {
97   // Create loop nest with one for-loop for each dimension of the target shape.
98   // Loops are added from outermost to innermost order with the ForLoopNest
99   // class so emit loops in order from most-major dimension down to most-minor
100   // dimension (of the target shape).
101   std::vector<llvm::Value*> array_multi_index(shape_.dimensions_size());
102   for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) {
103     int64 dimension = LayoutUtil::Major(shape_.layout(), i);
104     std::unique_ptr<ForLoop> loop = loop_nest->AddLoop(
105         /*start_index=*/0,
106         /*end_index=*/shape_.dimensions(dimension),
107         /*suffix=*/absl::StrFormat("dim.%d", dimension));
108     array_multi_index[dimension] = loop->GetIndVarValue();
109   }
110   return IrArray::Index(array_multi_index, shape_, index_type);
111 }
112 
EmitDynamicIndex(ForLoopNest * loop_nest,llvm::Type * index_type)113 IrArray::Index LoopEmitter::EmitDynamicIndex(ForLoopNest* loop_nest,
114                                              llvm::Type* index_type) {
115   CHECK_EQ(shape_.is_dynamic(), true);
116   // Create loop nest with one for-loop for each dynamic dimensions.
117   // Loops are added from outermost to innermost order with the ForLoopNest
118   // class so emit loops in order from most-major dimension down to most-minor
119   // dimension (of the target shape).
120   std::vector<llvm::Value*> array_multi_index(shape_.dimensions_size());
121   for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) {
122     int64 dimension = LayoutUtil::Major(shape_.layout(), i);
123     std::unique_ptr<ForLoop> loop = loop_nest->AddLoop(
124         /*suffix=*/absl::StrFormat("dim.%d", dimension),
125         /*start_index=*/llvm::ConstantInt::get(index_type, 0),
126         /*end_index=*/dynamic_dims_[dimension]);
127     array_multi_index[dimension] = loop->GetIndVarValue();
128   }
129   return IrArray::Index(array_multi_index, shape_, index_type);
130 }
131 
EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,llvm::Type * index_type,llvm::Value * base_index)132 std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock(
133     absl::string_view loop_name, llvm::Type* index_type,
134     llvm::Value* base_index) {
135   CHECK_NE(index_type, nullptr);
136   CHECK_EQ(base_index, nullptr)
137       << "XLA CPU implementation of"
138       << " LoopEmitter::EmitIndexAndSetExitBasicBlock doesn't support"
139       << " base_index, but it was requested.";
140 
141   if (ShapeUtil::IsScalar(shape_)) {
142     // No loop needed, so set exit_bb_ to nullptr.
143     exit_bb_ = nullptr;
144     return {IrArray::Index(index_type)};
145   }
146 
147   ForLoopNest loop_nest(loop_name, b_);
148 
149   IrArray::Index array_index = dynamic_dims_.empty()
150                                    ? EmitStaticIndex(&loop_nest, index_type)
151                                    : EmitDynamicIndex(&loop_nest, index_type);
152 
153   // Set IR builder insertion point to the loop body basic block of the
154   // innermost loop.
155   llvm::BasicBlock* innermost_body_bb = loop_nest.GetInnerLoopBodyBasicBlock();
156   b_->SetInsertPoint(innermost_body_bb,
157                      innermost_body_bb->getFirstInsertionPt());
158 
159   // Set exit_bb_ to the exit block of the loop nest.
160   exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock();
161   CHECK_NOTNULL(exit_bb_);
162 
163   return {array_index};
164 }
165 
EmitLoop(absl::string_view loop_name,llvm::Type * index_type)166 Status LoopEmitter::EmitLoop(absl::string_view loop_name,
167                              llvm::Type* index_type) {
168   if (index_type == nullptr) {
169     index_type = b_->getInt64Ty();
170   }
171 
172   for (const IrArray::Index& array_index :
173        EmitIndexAndSetExitBasicBlock(loop_name, index_type,
174                                      /*base_index*/ nullptr)) {
175     TF_RETURN_IF_ERROR(body_emitter_(array_index));
176   }
177 
178   // Set the insertion point of b_ to the loop exit, so that
179   // code emitted for later instructions will be correctly placed.
180   if (exit_bb_ != nullptr) {
181     b_->SetInsertPoint(exit_bb_);
182   }
183   return Status::OK();
184 }
185 
186 }  // namespace llvm_ir
187 }  // namespace xla
188