1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
17
18 #include <memory>
19 #include <utility>
20
21 #include "absl/strings/str_format.h"
22 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
23 #include "tensorflow/compiler/xla/shape_util.h"
24 #include "tensorflow/compiler/xla/status_macros.h"
25 #include "tensorflow/compiler/xla/types.h"
26 #include "tensorflow/core/lib/core/errors.h"
27 #include "tensorflow/core/platform/logging.h"
28 #include "tensorflow/core/platform/protobuf.h"
29 #include "tensorflow/core/platform/types.h"
30
31 namespace xla {
32 namespace llvm_ir {
33
LoopEmitter(const BodyEmitter & body_emitter,const Shape & shape,llvm::IRBuilder<> * b)34 LoopEmitter::LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape,
35 llvm::IRBuilder<>* b)
36 : body_emitter_(body_emitter), shape_(shape), b_(b) {}
37
LoopEmitter(const BodyEmitter & body_emitter,const Shape & shape,std::vector<llvm::Value * > dynamic_dims,llvm::IRBuilder<> * b)38 LoopEmitter::LoopEmitter(const BodyEmitter& body_emitter, const Shape& shape,
39 std::vector<llvm::Value*> dynamic_dims,
40 llvm::IRBuilder<>* b)
41 : LoopEmitter::LoopEmitter(body_emitter, shape, b) {
42 CHECK_EQ(dynamic_dims.size(), shape_.dimensions_size());
43 dynamic_dims_ = std::move(dynamic_dims);
44 }
45
LoopEmitter(const ElementGenerator & target_element_generator,const IrArray & target_array,llvm::IRBuilder<> * b)46 LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
47 const IrArray& target_array, llvm::IRBuilder<>* b)
48 : body_emitter_([=](const llvm_ir::IrArray::Index array_index) -> Status {
49 // Convert target_element_generator to a BodyEmitter.
50 TF_ASSIGN_OR_RETURN(llvm::Value * target_element,
51 target_element_generator(array_index));
52 target_array.EmitWriteArrayElement(array_index, target_element, b);
53 return Status::OK();
54 }),
55 shape_(target_array.GetShape()),
56 b_(b) {}
57
MakeBodyEmitterForMultiOutput(const ElementGenerator & target_element_generator,const std::vector<IrArray> & target_arrays,llvm::IRBuilder<> * b)58 static LoopEmitter::BodyEmitter MakeBodyEmitterForMultiOutput(
59 const ElementGenerator& target_element_generator,
60 const std::vector<IrArray>& target_arrays, llvm::IRBuilder<>* b) {
61 return [=](const llvm_ir::IrArray::Index array_index) {
62 TF_ASSIGN_OR_RETURN(llvm::Value * target_element,
63 target_element_generator(array_index));
64 CHECK(target_element->getType()->isStructTy())
65 << "This BodyEmitter is for multi-output, but target element "
66 "generator does not produce values of struct type.";
67 CHECK_EQ(target_element->getType()->getStructNumElements(),
68 target_arrays.size());
69
70 for (int64_t i = 0; i < target_arrays.size(); ++i) {
71 target_arrays[i].EmitWriteArrayElement(
72 array_index, b->CreateExtractValue(target_element, i), b);
73 }
74 return Status::OK();
75 };
76 }
77
LoopEmitter(const ElementGenerator & target_element_generator,absl::Span<const IrArray> target_arrays,llvm::IRBuilder<> * b)78 LoopEmitter::LoopEmitter(const ElementGenerator& target_element_generator,
79 absl::Span<const IrArray> target_arrays,
80 llvm::IRBuilder<>* b)
81 : body_emitter_(MakeBodyEmitterForMultiOutput(
82 target_element_generator,
83 std::vector<IrArray>(target_arrays.begin(), target_arrays.end()), b)),
84 shape_(target_arrays[0].GetShape()),
85 b_(b) {
86 // Sanity check: In multi-output fusion, all shapes produced must have the
87 // same dimensions.
88 for (const IrArray& array : target_arrays) {
89 CHECK(ShapeUtil::SameDimensions(shape_, array.GetShape()))
90 << ": '" << shape_.ShortDebugString() << "' does not match '"
91 << array.GetShape().ShortDebugString() << "'";
92 }
93 }
94
EmitStaticIndex(ForLoopNest * loop_nest,llvm::Type * index_type)95 IrArray::Index LoopEmitter::EmitStaticIndex(ForLoopNest* loop_nest,
96 llvm::Type* index_type) {
97 // Create loop nest with one for-loop for each dimension of the target shape.
98 // Loops are added from outermost to innermost order with the ForLoopNest
99 // class so emit loops in order from most-major dimension down to most-minor
100 // dimension (of the target shape).
101 std::vector<llvm::Value*> array_multi_index(shape_.dimensions_size());
102 for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) {
103 int64_t dimension = LayoutUtil::Major(shape_.layout(), i);
104 std::unique_ptr<ForLoop> loop = loop_nest->AddLoop(
105 /*start_index=*/0,
106 /*end_index=*/shape_.dimensions(dimension),
107 /*suffix=*/absl::StrFormat("dim.%d", dimension));
108 array_multi_index[dimension] = loop->GetIndVarValue();
109 }
110 return IrArray::Index(array_multi_index, shape_, index_type);
111 }
112
EmitDynamicIndex(ForLoopNest * loop_nest,llvm::Type * index_type)113 IrArray::Index LoopEmitter::EmitDynamicIndex(ForLoopNest* loop_nest,
114 llvm::Type* index_type) {
115 CHECK_EQ(shape_.is_dynamic(), true);
116 // Create loop nest with one for-loop for each dynamic dimensions.
117 // Loops are added from outermost to innermost order with the ForLoopNest
118 // class so emit loops in order from most-major dimension down to most-minor
119 // dimension (of the target shape).
120 std::vector<llvm::Value*> array_multi_index(shape_.dimensions_size());
121 for (int i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) {
122 int64_t dimension = LayoutUtil::Major(shape_.layout(), i);
123 std::unique_ptr<ForLoop> loop = loop_nest->AddLoop(
124 /*suffix=*/absl::StrFormat("dim.%d", dimension),
125 /*start_index=*/llvm::ConstantInt::get(index_type, 0),
126 /*end_index=*/dynamic_dims_[dimension]);
127 array_multi_index[dimension] = loop->GetIndVarValue();
128 }
129 return IrArray::Index(array_multi_index, shape_, index_type);
130 }
131
EmitIndexAndSetExitBasicBlock(absl::string_view loop_name,llvm::Type * index_type,llvm::Value * base_index)132 std::vector<IrArray::Index> LoopEmitter::EmitIndexAndSetExitBasicBlock(
133 absl::string_view loop_name, llvm::Type* index_type,
134 llvm::Value* base_index) {
135 CHECK_NE(index_type, nullptr);
136 CHECK_EQ(base_index, nullptr)
137 << "XLA CPU implementation of"
138 << " LoopEmitter::EmitIndexAndSetExitBasicBlock doesn't support"
139 << " base_index, but it was requested.";
140
141 if (ShapeUtil::IsScalar(shape_)) {
142 // No loop needed, so set exit_bb_ to nullptr.
143 exit_bb_ = nullptr;
144 return {IrArray::Index(index_type)};
145 }
146
147 ForLoopNest loop_nest(loop_name, b_);
148
149 IrArray::Index array_index = dynamic_dims_.empty()
150 ? EmitStaticIndex(&loop_nest, index_type)
151 : EmitDynamicIndex(&loop_nest, index_type);
152
153 // Set IR builder insertion point to the loop body basic block of the
154 // innermost loop.
155 llvm::BasicBlock* innermost_body_bb = loop_nest.GetInnerLoopBodyBasicBlock();
156 b_->SetInsertPoint(innermost_body_bb,
157 innermost_body_bb->getFirstInsertionPt());
158
159 // Set exit_bb_ to the exit block of the loop nest.
160 exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock();
161 CHECK_NOTNULL(exit_bb_);
162
163 return {array_index};
164 }
165
EmitLoop(absl::string_view loop_name,llvm::Type * index_type)166 Status LoopEmitter::EmitLoop(absl::string_view loop_name,
167 llvm::Type* index_type) {
168 if (index_type == nullptr) {
169 index_type = b_->getInt64Ty();
170 }
171
172 for (const IrArray::Index& array_index :
173 EmitIndexAndSetExitBasicBlock(loop_name, index_type,
174 /*base_index*/ nullptr)) {
175 TF_RETURN_IF_ERROR(body_emitter_(array_index));
176 }
177
178 // Set the insertion point of b_ to the loop exit, so that
179 // code emitted for later instructions will be correctly placed.
180 if (exit_bb_ != nullptr) {
181 b_->SetInsertPoint(exit_bb_);
182 }
183 return Status::OK();
184 }
185
186 } // namespace llvm_ir
187 } // namespace xla
188