1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
17
18 #include "llvm/IR/Constants.h"
19 #include "llvm/IR/Instructions.h"
20 #include "tensorflow/compiler/xla/layout_util.h"
21 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/compiler/xla/statusor.h"
24 #include "tensorflow/compiler/xla/util.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/types.h"
28
29 namespace xla {
30 namespace llvm_ir {
31
// Constructs a multidimensional index from a linear index by repeatedly
// dividing out dimension sizes, walking the shape's layout from the most
// minor dimension to the most major one.
IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
                      llvm::IRBuilder<>* ir_builder)
    : multidim_(ShapeUtil::Rank(shape)),
      linear_(linear),
      layout_(shape.layout()),
      dims_(shape.dimensions().begin(), shape.dimensions().end()) {
  // A layout is required: it supplies the minor-to-major order used to
  // delinearize below.
  CHECK(LayoutUtil::HasLayout(shape))
      << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
      << " should have a layout.";
  // Product of the sizes of all dimensions more minor than the current one.
  int64 divisor = 1;
  for (int64 i = 0; i < layout_.minor_to_major_size(); ++i) {
    int64 dimension = layout_.minor_to_major(i);
    int64 size_of_current_dimension = shape.dimensions(dimension);

    // If i is not the last dimension, compute
    // (linear_index / divisor) % current_dimension.
    // If i is the last dimension, we can skip the mod, because we assume that
    // linear is in bounds.
    //
    // TODO(jlebar): We could add bounds checks here and elsewhere in this file,
    // guarded under some sort of xla-memcheck flag. This might be particularly
    // useful because cuda-memcheck can't help us much in XLA: Most of our
    // memory lives in one big allocation, so cuda-memcheck can't detect
    // out-of-bounds accesses.
    auto* quot = ir_builder->CreateUDiv(linear, ir_builder->getInt64(divisor));
    if (i < layout_.minor_to_major_size() - 1) {
      multidim_[dimension] = ir_builder->CreateURem(
          quot, ir_builder->getInt64(size_of_current_dimension));
    } else {
      multidim_[dimension] = quot;
    }
    divisor *= size_of_current_dimension;
  }
}
66
// Constructs an index that already has both representations: the
// per-dimension index values and the corresponding linear index.
IrArray::Index::Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
                      llvm::Value* linear, const Shape& shape)
    : multidim_(multidim.begin(), multidim.end()),
      linear_(linear),
      layout_(shape.layout()),
      dims_(shape.dimensions().begin(), shape.dimensions().end()) {
  // One index value is required per dimension of the shape.
  CHECK_EQ(shape.dimensions_size(), multidim.size());
  // The layout is captured so the linear index can later be validated
  // against other shapes (see LinearValidOnShape).
  CHECK(LayoutUtil::HasLayout(shape))
      << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
      << " should have a layout.";
}
78
Index(tensorflow::gtl::ArraySlice<llvm::Value * > multidim,const Shape & shape,llvm::IRBuilder<> * ir_builder)79 IrArray::Index::Index(tensorflow::gtl::ArraySlice<llvm::Value*> multidim,
80 const Shape& shape, llvm::IRBuilder<>* ir_builder)
81 : multidim_(multidim.begin(), multidim.end()),
82 layout_(shape.layout()),
83 dims_(shape.dimensions().begin(), shape.dimensions().end()) {
84 CHECK_EQ(shape.dimensions_size(), multidim.size());
85 CHECK(LayoutUtil::HasLayout(shape));
86 linear_ = Linearize(AsInt64Slice(shape.dimensions()), ir_builder);
87 }
88
// Wraps an LLVM pointer as an XLA-shaped array. `shape` must outlive this
// IrArray, since only its address is stored.
IrArray::IrArray(llvm::Value* base_ptr, const Shape& shape)
    : base_ptr_(base_ptr), shape_(&shape) {
  TF_CHECK_OK(ShapeUtil::ValidateShape(shape));
  CHECK(base_ptr_->getType()->isPointerTy());
  // Peel off nested LLVM array types to find the scalar element type, and
  // count how many levels of array nesting the IR type has.
  int depth = 0;
  element_type_ =
      llvm::cast<llvm::PointerType>(base_ptr_->getType())->getElementType();
  while (llvm::ArrayType* array_type =
             llvm::dyn_cast<llvm::ArrayType>(element_type_)) {
    element_type_ = array_type->getElementType();
    ++depth;
  }

  // Sanity-check that the IR type's nesting depth is consistent with the
  // XLA shape's rank (scalars/non-arrays map to depth 0 or 1).
  if (!ShapeUtil::IsArray(*shape_) || ShapeUtil::IsScalar(*shape_)) {
    DCHECK(depth == 1 || depth == 0) << depth;
  } else {
    DCHECK_EQ(depth, ShapeUtil::Rank(*shape_)) << shape.ShortDebugString();
  }
}
108
// Returns whether the given linear index is valid on the given shape.
LinearValidOnShape(const Shape & a) const110 bool IrArray::Index::LinearValidOnShape(const Shape& a) const {
111 auto b = ShapeUtil::MakeShape(PRED /* irrelevant */, dims_);
112 *b.mutable_layout() = layout_;
113 return linear_ != nullptr &&
114 ContainersEqual(
115 ShapeUtil::StripDegenerateDimensions(a).dimensions(),
116 ShapeUtil::StripDegenerateDimensions(b).dimensions()) &&
117 LayoutUtil::Equal(ShapeUtil::StripDegenerateDimensions(a).layout(),
118 ShapeUtil::StripDegenerateDimensions(b).layout());
119 }
120
// Given this index into `output_shape` (the reshape's result), computes the
// corresponding index into `input_shape` (the reshape's operand).
IrArray::Index IrArray::Index::SourceIndexOfReshape(
    const Shape& output_shape, const Shape& input_shape,
    llvm::IRBuilder<>* builder) const {
  const auto& target_index = *this;
  CHECK_EQ(target_index.size(), ShapeUtil::Rank(output_shape));
  // Decompose the reshape into "common factor" groups: runs of input
  // dimensions that collapse to runs of output dimensions with the same
  // total element count. Each group's indices can be translated
  // independently of the others.
  std::vector<std::pair<int64, int64>> common_factors =
      CommonFactors(AsInt64Slice(input_shape.dimensions()),
                    AsInt64Slice(output_shape.dimensions()));
  std::vector<llvm::Value*> source_multidim_index(
      ShapeUtil::Rank(input_shape),
      llvm::UndefValue::get(builder->getInt64Ty()));
  // We compute the source indices in each common factor from only the target
  // indices in the same common factor.
  for (ssize_t k = common_factors.size() - 2; k >= 0; --k) {
    // Linearize the target indices of group k, treating the group's output
    // dimensions as row-major.
    llvm::Value* logical_linear_index =
        Index(tensorflow::gtl::ArraySlice<llvm::Value*>(
                  multidim_, common_factors[k].second,
                  common_factors[k + 1].second - common_factors[k].second))
            .Linearize(
                tensorflow::gtl::ArraySlice<int64>(
                    AsInt64Slice(output_shape.dimensions()),
                    common_factors[k].second,
                    common_factors[k + 1].second - common_factors[k].second),
                builder);
    // Delinearizes logical_linear_index for the source array in row-major
    // collapsed order. The first rank-1 indices are the remainder of the
    // linear index by each dimension size.
    for (int64 i = common_factors[k + 1].first - 1;
         i >= common_factors[k].first; --i) {
      llvm::Value* divisor = builder->getInt64(input_shape.dimensions(i));
      if (input_shape.dimensions(i) == 1) {
        // Degenerate dimension: the only valid index is 0.
        source_multidim_index[i] = builder->getInt64(0);
      } else if (i == common_factors[k].first) {
        // Most-major dimension of the group: what's left of the quotient is
        // the index (no mod needed, assuming the index is in bounds).
        source_multidim_index[i] = logical_linear_index;
      } else {
        source_multidim_index[i] =
            builder->CreateURem(logical_linear_index, divisor);
      }
      logical_linear_index = builder->CreateUDiv(logical_linear_index, divisor);
    }
  }

  // If the reshape does not change the physical layout, the linear index is
  // identical for source and target and can be propagated.
  if (linear() != nullptr &&
      ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) {
    return Index(source_multidim_index, linear(), input_shape);
  }
  return Index(source_multidim_index);
}
169
SourceIndexOfSlice(const Shape & shape,tensorflow::gtl::ArraySlice<int64> starts,tensorflow::gtl::ArraySlice<int64> strides,llvm::IRBuilder<> * builder) const170 IrArray::Index IrArray::Index::SourceIndexOfSlice(
171 const Shape& shape, tensorflow::gtl::ArraySlice<int64> starts,
172 tensorflow::gtl::ArraySlice<int64> strides,
173 llvm::IRBuilder<>* builder) const {
174 Index source_index(multidim_.size());
175 for (int i = 0; i < multidim_.size(); ++i) {
176 int64 stride = strides[i];
177 auto type = multidim_[i]->getType();
178
179 if (stride != 1) {
180 source_index[i] = builder->CreateAdd(
181 builder->CreateMul(multidim_[i],
182 llvm::ConstantInt::get(type, stride)),
183 llvm::ConstantInt::get(type, starts[i]));
184 } else {
185 source_index[i] = builder->CreateAdd(
186 multidim_[i], llvm::ConstantInt::get(type, starts[i]));
187 }
188 }
189 return source_index;
190 }
191
SourceIndexOfTranspose(const Shape & shape,const Shape & operand_shape,tensorflow::gtl::ArraySlice<int64> dimension_mapping,llvm::IRBuilder<> * builder) const192 IrArray::Index IrArray::Index::SourceIndexOfTranspose(
193 const Shape& shape, const Shape& operand_shape,
194 tensorflow::gtl::ArraySlice<int64> dimension_mapping,
195 llvm::IRBuilder<>* builder) const {
196 std::vector<llvm::Value*> operand_multidim_index =
197 Permute(dimension_mapping, multidim());
198 if (linear() != nullptr &&
199 ShapeUtil::TransposeIsBitcast(operand_shape, shape, dimension_mapping)) {
200 return Index(operand_multidim_index, linear(), operand_shape);
201 }
202 return Index(operand_multidim_index);
203 }
204
Linearize(tensorflow::gtl::ArraySlice<int64> dimensions,llvm::IRBuilder<> * builder) const205 llvm::Value* IrArray::Index::Linearize(
206 tensorflow::gtl::ArraySlice<int64> dimensions,
207 llvm::IRBuilder<>* builder) const {
208 // Each dimension is multiplied by the product of the sizes of all
209 // earlier dimensions and added to the accumulator logical_linear_index.
210 llvm::Value* logical_linear_index = builder->getInt64(0);
211 int64 multiplier = 1;
212 for (ssize_t i = size() - 1; i >= 0; --i) {
213 llvm::Value* addend =
214 builder->CreateMul((*this)[i], builder->getInt64(multiplier), "",
215 /*HasNUW=*/true, /*HasNSW=*/true);
216 logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "",
217 /*HasNUW=*/true, /*HasNSW=*/true);
218 multiplier *= dimensions[i];
219 }
220 return logical_linear_index;
221 }
222
// Emits IR computing the address of the element at `index`, handling
// implicit broadcasting over size-1 dimensions and preferring a flat GEP on
// the linear index when it is valid for this array's shape.
llvm::Value* IrArray::EmitArrayElementAddress(
    const IrArray::Index& index, llvm::IRBuilder<>* ir_builder,
    tensorflow::StringPiece name) const {
  if (ShapeUtil::IsScalar(*shape_)) {
    // Special handling of scalars: a scalar pretends to have the same value for
    // every index, thus effectively implementing broadcasting of its value
    // over higher-rank arrays.
    return base_ptr_;
  }
  CHECK_EQ(index.size(), ShapeUtil::Rank(*shape_));

  std::vector<llvm::Value*> actual_index;
  bool is_implicit_broadcast = false;
  // We perform broadcasting when the operand shape has dimension(s) of size
  // 1. In this case we fix the index value for that dimension to zero. This
  // effectively broadcasts along this dimension.
  for (int64 i = 0; i < index.size(); ++i) {
    auto dim = shape_->dimensions(i);
    actual_index.push_back(dim == 1 ? ir_builder->getInt64(0) : index[i]);
    is_implicit_broadcast |= dim == 1;
  }

  // Fast path: if the index's linear form is valid for this shape (and no
  // dimension was pinned to zero above), address the element directly via a
  // single GEP on a flat element-typed pointer.
  if (!is_implicit_broadcast && index.LinearValidOnShape(*shape_)) {
    llvm::Module* module =
        ir_builder->GetInsertBlock()->getParent()->getParent();
    return ir_builder->CreateInBoundsGEP(
        ir_builder->CreateBitCast(
            base_ptr_, PrimitiveTypeToIrType(shape_->element_type(), module)
                           ->getPointerTo()),
        {index.linear()}, llvm_ir::AsStringRef(name));
  }

  // "base_ptr_" has the type of "<ir_type_for_its_shape>*"
  // (e.g. [3 x [2 x float]]*). Therefore, the address of the indexed element
  // should be computed by
  //
  //   getelementptr base_ptr_, 0, most major index, ..., most minor index
  std::vector<llvm::Value*> gep_indices(1, ir_builder->getInt64(0));
  for (int64 i = 0; i < LayoutUtil::MinorToMajor(*shape_).size(); ++i) {
    // GEP indices go from most major to most minor.
    int64 dimension = LayoutUtil::Major(shape_->layout(), i);
    gep_indices.push_back(actual_index[dimension]);
  }
  return ir_builder->CreateInBoundsGEP(base_ptr_, gep_indices,
                                       llvm_ir::AsStringRef(name));
}
268
AnnotateLoadStoreInstructionWithMetadata(llvm::Instruction * instruction) const269 void IrArray::AnnotateLoadStoreInstructionWithMetadata(
270 llvm::Instruction* instruction) const {
271 CHECK(llvm::isa<llvm::LoadInst>(instruction) ||
272 llvm::isa<llvm::StoreInst>(instruction));
273 CHECK(!llvm::isa<llvm::StoreInst>(instruction) || !is_invariant_)
274 << "Trying to create a store to an invariant IRArray.";
275
276 for (const auto& kind_md_pair : metadata_) {
277 instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
278 }
279 }
280
EmitReadArrayElement(const Index & index,llvm::IRBuilder<> * ir_builder,tensorflow::StringPiece name) const281 llvm::Value* IrArray::EmitReadArrayElement(const Index& index,
282 llvm::IRBuilder<>* ir_builder,
283 tensorflow::StringPiece name) const {
284 llvm::Value* element_address =
285 EmitArrayElementAddress(index, ir_builder, name);
286 llvm::LoadInst* load = ir_builder->CreateLoad(element_address);
287 AnnotateLoadStoreInstructionWithMetadata(load);
288 return load;
289 }
290
EmitWriteArrayElement(const Index & index,llvm::Value * value,llvm::IRBuilder<> * ir_builder) const291 void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value,
292 llvm::IRBuilder<>* ir_builder) const {
293 llvm::Value* element_address = EmitArrayElementAddress(index, ir_builder);
294 llvm::StoreInst* store = ir_builder->CreateStore(value, element_address);
295 AnnotateLoadStoreInstructionWithMetadata(store);
296 }
297
CastToShape(const Shape & new_shape,llvm::IRBuilder<> * ir_builder) const298 IrArray IrArray::CastToShape(const Shape& new_shape,
299 llvm::IRBuilder<>* ir_builder) const {
300 llvm::Module* module = ir_builder->GetInsertBlock()->getParent()->getParent();
301 llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module);
302 return IrArray(
303 ir_builder->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()),
304 new_shape);
305 }
306
BumpIndex(const Index & index,int64 which_dimension,int64 addend,llvm::IRBuilder<> * ir_builder)307 /* static */ IrArray::Index IrArray::BumpIndex(const Index& index,
308 int64 which_dimension,
309 int64 addend,
310 llvm::IRBuilder<>* ir_builder) {
311 Index new_index = index;
312 new_index[which_dimension] = ir_builder->CreateAdd(
313 index[which_dimension], ir_builder->getInt64(addend), "", /*HasNUW=*/true,
314 /*HasNSW=*/true);
315 return new_index;
316 }
317
318 } // namespace llvm_ir
319 } // namespace xla
320