/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"

#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace llvm_ir {
IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
                      llvm::Value* linear, const Shape& shape,
                      llvm::Type* index_type)
    : Index(multidim, shape, index_type) {
  CHECK_NE(linear, nullptr);
  linear_ = linear;
}

void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
                                 llvm::Value* linear, const Shape& shape,
                                 llvm::IRBuilder<>* b) const {
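  // Illustrative example: for a shape f32[2,3] with layout {1,0} (dimension 1
  // minor, i.e. row-major), linear index 4 delinearizes to the
  // multidimensional index (1, 1): 4 % 3 == 1 for dimension 1, and
  // 4 / 3 == 1 for dimension 0.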
  int64 divisor = 1;
  const Layout& layout = shape.layout();
  for (int64 i = 0; i < layout.minor_to_major_size(); ++i) {
    int64 dimension = layout.minor_to_major(i);
    int64 size_of_current_dimension = shape.dimensions(dimension);

    // If i is not the last dimension, compute
    //   (linear_index / divisor) % current_dimension.
    // If i is the last dimension, we can skip the mod, because we assume that
    // linear is in bounds.
    //
    // TODO(jlebar): We could add bounds checks here and elsewhere in this file,
    // guarded under some sort of xla-memcheck flag. This might be particularly
    // useful because cuda-memcheck can't help us much in XLA: Most of our
    // memory lives in one big allocation, so cuda-memcheck can't detect
    // out-of-bounds accesses.
    auto* quot = b->CreateUDiv(linear, GetConstantWithIndexType(divisor));
    if (i < layout.minor_to_major_size() - 1) {
      (*multidim)[dimension] = b->CreateURem(
          quot, GetConstantWithIndexType(size_of_current_dimension));
    } else {
      (*multidim)[dimension] = quot;
    }
    divisor *= size_of_current_dimension;
  }
}

IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
                      llvm::IRBuilder<>* b)
    : multidim_(shape.rank()),
      linear_(linear),
      layout_(shape.layout()),
      dims_(shape.dimensions().begin(), shape.dimensions().end()) {
  CHECK_NE(linear, nullptr);
  index_type_ = linear->getType();
  CHECK(LayoutUtil::HasLayout(shape))
      << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
      << " should have a layout.";
  Delinearize(&multidim_, linear, shape, b);
}

IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
                      absl::Span<int64 const> dimensions,
                      llvm::Type* index_type)
    : Index(multidim, ShapeUtil::MakeShape(/*arbitrary*/ PRED, dimensions),
            index_type) {}

IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
                      const Shape& shape, llvm::Type* index_type)
    : multidim_(multidim.begin(), multidim.end()),
      linear_(nullptr),
      layout_(shape.layout()),
      dims_(shape.dimensions().begin(), shape.dimensions().end()),
      index_type_(index_type) {
  CHECK_NE(index_type_, nullptr);
  CHECK_EQ(shape.dimensions_size(), multidim.size());
  for (const auto* dim : multidim) {
    CHECK_NE(dim, nullptr);
  }
  CHECK(LayoutUtil::HasLayout(shape))
      << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
      << " should have a layout.";
}

IrArray::IrArray(llvm::Value* base_ptr, Shape shape)
    : base_ptr_(base_ptr), shape_(std::move(shape)) {
  // `shape` has been moved into shape_, so validate and log shape_ below.
  TF_CHECK_OK(ShapeUtil::ValidateShape(shape_));
  CHECK(base_ptr_->getType()->isPointerTy());
  int depth = 0;
  element_type_ =
      llvm::cast<llvm::PointerType>(base_ptr_->getType())->getElementType();
  while (llvm::ArrayType* array_type =
             llvm::dyn_cast<llvm::ArrayType>(element_type_)) {
    element_type_ = array_type->getElementType();
    ++depth;
  }

  if (!shape_.IsArray() || ShapeUtil::IsScalar(shape_)) {
    DCHECK(depth == 1 || depth == 0) << depth;
  } else {
    DCHECK_EQ(depth, shape_.rank()) << shape_.ShortDebugString();
  }
}

// Returns whether the given linear index is valid on the given shape.
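// That is, checks that this index was constructed with a (non-null) linear
// value for a shape that has the same number of elements as `a` and a layout
// such that a reshape between the two shapes is a bitcast (no data movement).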
bool IrArray::Index::LinearValidOnShape(const Shape& a) const {
  auto b = ShapeUtil::MakeShape(a.element_type(), dims_);
  *b.mutable_layout() = layout_;
  return linear_ != nullptr &&
         ShapeUtil::ElementsIn(a) == ShapeUtil::ElementsIn(b) &&
         ShapeUtil::ReshapeIsBitcast(a, b);
}

IrArray::Index IrArray::Index::SourceIndexOfReshape(
    const Shape& output_shape, const Shape& input_shape,
    llvm::IRBuilder<>* builder) const {
  CHECK_EQ(multidim_.size(), output_shape.rank());
  const auto common_factors =
      CommonFactors(AsInt64Slice(input_shape.dimensions()),
                    AsInt64Slice(output_shape.dimensions()));
  std::vector<llvm::Value*> source_multidim_index(
      input_shape.rank(), llvm::UndefValue::get(index_type_));
  // We compute the source indices in each common factor from only the target
  // indices in the same common factor.
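  // Illustrative example: for a reshape from [6, 5] (input) to [2, 3, 5]
  // (output), CommonFactors returns the boundary pairs {(0, 0), (1, 2),
  // (2, 3)}: output dimensions {0, 1} collapse onto input dimension {0}, and
  // output dimension {2} maps onto input dimension {1}.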
  for (ssize_t k = common_factors.size() - 2; k >= 0; --k) {
    absl::Span<int64 const> dimensions =
        AsInt64Slice(output_shape.dimensions())
            .subspan(common_factors[k].second,
                     common_factors[k + 1].second - common_factors[k].second);
    llvm::Value* logical_linear_index =
        Index(absl::Span<llvm::Value* const>(multidim_).subspan(
                  common_factors[k].second,
                  common_factors[k + 1].second - common_factors[k].second),
              dimensions, index_type_)
            .Linearize(dimensions, builder);
    // Delinearize logical_linear_index for the source array in row-major
    // collapsed order. All but the most-major index in this group are
    // remainders of the running linear index by the corresponding dimension
    // size; the most-major index gets whatever quotient remains.
    for (int64 i = common_factors[k + 1].first - 1;
         i >= common_factors[k].first; --i) {
      llvm::Value* divisor =
          GetConstantWithIndexType(input_shape.dimensions(i));
      if (input_shape.dimensions(i) == 1) {
        source_multidim_index[i] = GetConstantWithIndexType(0);
      } else if (i == common_factors[k].first) {
        source_multidim_index[i] = logical_linear_index;
      } else {
        source_multidim_index[i] =
            builder->CreateURem(logical_linear_index, divisor);
      }
      logical_linear_index = builder->CreateUDiv(logical_linear_index, divisor);
    }
  }

  if (linear() != nullptr && LayoutUtil::HasLayout(input_shape) &&
      LayoutUtil::HasLayout(output_shape) &&
      ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) {
    return Index(source_multidim_index, linear(), input_shape, index_type_);
  }
  return Index(source_multidim_index, input_shape, index_type_);
}

IrArray::Index IrArray::Index::SourceIndexOfSlice(
    const Shape& operand_shape, absl::Span<const int64> starts,
    absl::Span<const int64> strides, llvm::IRBuilder<>* builder) const {
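  // Along each dimension i, output index index[i] maps to operand index
  // starts[i] + strides[i] * index[i]. E.g. (illustrative), with starts = {2}
  // and strides = {3}, output index 4 reads operand element 2 + 3 * 4 = 14.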
  std::vector<llvm::Value*> source_multi_index(multidim_.size());
  for (int i = 0; i < multidim_.size(); ++i) {
    int64 stride = strides[i];
    auto type = multidim_[i]->getType();

    if (stride != 1) {
      source_multi_index[i] = builder->CreateAdd(
          builder->CreateMul(multidim_[i],
                             llvm::ConstantInt::get(type, stride)),
          llvm::ConstantInt::get(type, starts[i]));
    } else {
      source_multi_index[i] = builder->CreateAdd(
          multidim_[i], llvm::ConstantInt::get(type, starts[i]));
    }
  }
  return Index(source_multi_index, operand_shape, index_type_);
}

IrArray::Index IrArray::Index::SourceIndexOfTranspose(
    const Shape& shape, const Shape& operand_shape,
    absl::Span<const int64> dimension_mapping,
    llvm::IRBuilder<>* builder) const {
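  // Per HLO transpose semantics, output dimension i corresponds to operand
  // dimension dimension_mapping[i], so component i of this (output) index
  // becomes component dimension_mapping[i] of the operand index.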
  std::vector<llvm::Value*> operand_multidim_index =
      Permute(dimension_mapping, multidim());

  if (linear() != nullptr && LayoutUtil::HasLayout(operand_shape) &&
      LayoutUtil::HasLayout(shape) &&
      ShapeUtil::TransposeIsBitcast(operand_shape, shape, dimension_mapping)) {
    return Index(operand_multidim_index, linear(), operand_shape, index_type_);
  }

  return Index(operand_multidim_index, operand_shape, index_type_);
}

IrArray::Index IrArray::Index::SourceIndexOfBitcast(
    const Shape& shape, const Shape& operand_shape,
    llvm::IRBuilder<>* builder) const {
  CHECK(LayoutUtil::HasLayout(shape) && LayoutUtil::HasLayout(operand_shape));
  // In case the bitcast is just a reshape, we can use SourceIndexOfReshape()
  // instead. This will reuse linear() if possible, so we don't have to build a
  // new 'linear_index'.
  if (ShapeUtil::ReshapeIsBitcast(operand_shape, shape)) {
    return SourceIndexOfReshape(shape, operand_shape, builder);
  }

  // First linearize the index coming from the output of the bitcast. We want
  // the physical index of the element in the buffer. This is like Linearize,
  // but takes the layout into account.
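  // E.g. (illustrative), for shape f32[2,3] with layout {0,1} (dimension 0
  // minor), element (i0, i1) has physical linear index i0 + 2 * i1.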
  int64 scale = 1;
  llvm::Value* linear_index = GetConstantWithIndexType(0);
  for (auto dimension : LayoutUtil::MinorToMajor(shape)) {
    linear_index = builder->CreateAdd(
        linear_index,
        builder->CreateMul(multidim_[dimension],
                           GetConstantWithIndexType(scale), "",
                           /*HasNUW=*/true, /*HasNSW=*/true),
        "", /*HasNUW=*/true, /*HasNSW=*/true);
    scale *= shape.dimensions(dimension);
  }

  return Index(linear_index, operand_shape, builder);
}

IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
    const Shape& shape, const Shape& operand_shape,
    absl::Span<const int64> dimension_mapping,
    llvm::IRBuilder<>* builder) const {
  int64 rank = operand_shape.rank();
  std::vector<llvm::Value*> source_index(rank);
  for (int64 i = 0; i < rank; ++i) {
    source_index[i] = multidim_[dimension_mapping[i]];
  }
  if (linear_ == nullptr || !LayoutUtil::HasLayout(operand_shape) ||
      !LayoutUtil::HasLayout(shape)) {
    return Index(source_index, operand_shape, index_type_);
  }
  // High-level idea: we can reuse the linear index if the broadcasted
  // dimensions are contiguous, and this part of the operation is a bitcast.
  // The other dimensions can be masked out with a div and a mod operation.
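  // Illustrative example: broadcasting f32[3,4] into f32[5,3,4,2] with
  // dimension_mapping = {1, 2} and row-major layouts. The broadcast dimensions
  // occupy contiguous physical positions, so the operand's linear index is
  // (linear_ / 2) % 12: the division strips the minor dimension of size 2, and
  // the modulus by 3 * 4 = 12 strips the major dimension of size 5.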
  std::vector<int64> logical_to_physical =
      LayoutUtil::MakeLogicalToPhysical(shape.layout());
  int64 output_rank = shape.rank();
  // The minimum physical dimension that is broadcasted.
  int64 min_broadcasted_dimension = output_rank;
  // The maximum physical dimension that is broadcasted.
  int64 max_broadcasted_dimension = -1;
  for (int64 i = 0; i < rank; ++i) {
    int64 physical_dim = logical_to_physical[dimension_mapping[i]];
    min_broadcasted_dimension =
        std::min(min_broadcasted_dimension, physical_dim);
    max_broadcasted_dimension =
        std::max(max_broadcasted_dimension, physical_dim);
  }
  bool contiguous_broadcast_dimensions =
      max_broadcasted_dimension - min_broadcasted_dimension == rank - 1;
  if (!contiguous_broadcast_dimensions) {
    return Index(source_index, operand_shape, index_type_);
  }
  // Check if the mapped dimensions are a bitcast.
  std::vector<int64> operand_logical_to_physical =
      LayoutUtil::MakeLogicalToPhysical(operand_shape.layout());
  for (int64 i = 0; i < rank; ++i) {
    if (operand_logical_to_physical[i] !=
        logical_to_physical[dimension_mapping[i]] - min_broadcasted_dimension) {
      return Index(source_index, operand_shape, index_type_);
    }
  }
  llvm::Value* linear = linear_;
  int64 divisor = 1;
  for (int64 i = max_broadcasted_dimension + 1; i < output_rank; ++i) {
    divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
  }
  if (divisor > 1) {
    linear = builder->CreateUDiv(linear, GetConstantWithIndexType(divisor));
  }
  if (min_broadcasted_dimension > 0) {
    int64 mod = 1;
    for (int64 i = min_broadcasted_dimension; i <= max_broadcasted_dimension;
         ++i) {
      mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
    }
    linear = builder->CreateURem(linear, GetConstantWithIndexType(mod));
  }
  return Index(source_index, linear, operand_shape, index_type_);
}

llvm::Value* IrArray::Index::Linearize(absl::Span<const int64> dimensions,
                                       llvm::IRBuilder<>* builder) const {
  // Each index component is multiplied by the product of the sizes of all
  // subsequent (more minor) dimensions and added to the accumulator
  // logical_linear_index.
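  // E.g. (illustrative), for dimensions {2, 3}, the index (i0, i1) linearizes
  // to i0 * 3 + i1.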
  CHECK_EQ(size(), dimensions.size());
  llvm::Value* logical_linear_index = GetConstantWithIndexType(0);
  int64 multiplier = 1;
  for (ssize_t i = size() - 1; i >= 0; --i) {
    llvm::Value* addend =
        builder->CreateMul((*this)[i], GetConstantWithIndexType(multiplier), "",
                           /*HasNUW=*/true, /*HasNSW=*/true);
    addend = builder->CreateZExtOrTrunc(addend, index_type_);
    logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "",
                                              /*HasNUW=*/true, /*HasNSW=*/true);
    multiplier *= dimensions[i];
  }
  return logical_linear_index;
}

llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
                                              llvm::IRBuilder<>* b,
                                              absl::string_view name,
                                              bool use_linear_index) const {
  if (ShapeUtil::IsScalar(shape_)) {
    // Special handling of scalars: a scalar pretends to have the same value
    // for every index, thus effectively implementing broadcasting of its value
    // over higher-rank arrays.
    return base_ptr_;
  }
  CHECK_EQ(index.size(), shape_.rank());
  CHECK(index.ShapeIsCompatible(shape_));

  if (use_linear_index && index.LinearValidOnShape(shape_)) {
    llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
    return b->CreateInBoundsGEP(
        b->CreateBitCast(base_ptr_,
                         PrimitiveTypeToIrType(shape_.element_type(), module)
                             ->getPointerTo()),
        {index.linear()}, llvm_ir::AsStringRef(name));
  }

  std::vector<llvm::Value*> actual_index;
  for (int64 i = 0; i < index.size(); ++i) {
    // When dimension i is of size 1, LLVM optimization is able to replace
    // index[i] with 0. However, setting index[i] to 0 here still allows LLVM
    // to produce better code in some cases.
    auto dim = shape_.dimensions(i);
    actual_index.push_back(
        dim == 1 ? llvm::ConstantInt::get(index[i]->getType(), 0) : index[i]);
  }

  // "base_ptr_" has the type of "<ir_type_for_its_shape>*"
  // (e.g. [3 x [2 x float]]*). Therefore, the address of the indexed element
  // should be computed by
  //
  //   getelementptr base_ptr_, 0, most major index, ..., most minor index
  CHECK_GT(index.size(), 0);
  std::vector<llvm::Value*> gep_indices(
      1, llvm::ConstantInt::get(index[0]->getType(), 0));
  for (int64 i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) {
    int64 dimension = LayoutUtil::Major(shape_.layout(), i);
    gep_indices.push_back(actual_index[dimension]);
  }
  return b->CreateInBoundsGEP(base_ptr_, gep_indices,
                              llvm_ir::AsStringRef(name));
}

void IrArray::AnnotateLoadStoreInstructionWithMetadata(
    llvm::Instruction* instruction) const {
  CHECK(llvm::isa<llvm::LoadInst>(instruction) ||
        llvm::isa<llvm::StoreInst>(instruction));
  CHECK(!llvm::isa<llvm::StoreInst>(instruction) || !is_invariant_)
      << "Trying to create a store to an invariant IrArray.";

  for (const auto& kind_md_pair : metadata_) {
    instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
  }
}

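// Example usage (illustrative sketch, not taken from a real call site):
//   IrArray::Index index(linear, shape, &b);
//   llvm::Value* elem =
//       src_array.EmitReadArrayElement(index, &b, "elem",
//                                      /*use_linear_index=*/true);
//   dst_array.EmitWriteArrayElement(index, elem, &b,
//                                   /*use_linear_index=*/true);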
llvm::Value* IrArray::EmitReadArrayElement(const Index& index,
                                           llvm::IRBuilder<>* b,
                                           absl::string_view name,
                                           bool use_linear_index) const {
  llvm::Value* element_address =
      EmitArrayElementAddress(index, b, name, use_linear_index);
  llvm::LoadInst* load = b->CreateLoad(element_address);
  AnnotateLoadStoreInstructionWithMetadata(load);
  return load;
}

void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value,
                                    llvm::IRBuilder<>* b,
                                    bool use_linear_index) const {
  llvm::Value* element_address =
      EmitArrayElementAddress(index, b, "", use_linear_index);
  llvm::StoreInst* store = b->CreateStore(value, element_address);
  AnnotateLoadStoreInstructionWithMetadata(store);
}

IrArray IrArray::CastToShape(const Shape& new_shape,
                             llvm::IRBuilder<>* b) const {
  llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
  llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module);
  IrArray new_irarray(
      b->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()), new_shape);
  new_irarray.metadata_ = metadata_;
  return new_irarray;
}

}  // namespace llvm_ir
}  // namespace xla