1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
17
18 #include <tuple>
19 #include <vector>
20
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/Instructions.h"
23 #include "llvm/IR/Value.h"
24 #include "tensorflow/compiler/xla/layout_util.h"
25 #include "tensorflow/compiler/xla/permutation_util.h"
26 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/statusor.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/types.h"
33
34 namespace xla {
35 namespace llvm_ir {
36
// Constructs an index that carries both a multidimensional form and a
// precomputed linear form. Delegates all multidim/shape validation to the
// (multidim, shape, index_type) constructor and then records `linear`,
// which must be non-null.
IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
                      llvm::Value* linear, const Shape& shape,
                      llvm::Type* index_type)
    : Index(multidim, shape, index_type) {
  CHECK_NE(linear, nullptr);
  linear_ = linear;
}
44
Delinearize(std::vector<llvm::Value * > * multidim,llvm::Value * linear,const Shape & shape,llvm::IRBuilder<> * b) const45 void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
46 llvm::Value* linear, const Shape& shape,
47 llvm::IRBuilder<>* b) const {
48 int64_t divisor = 1;
49 const Layout& layout = shape.layout();
50 for (int64_t i = 0; i < layout.minor_to_major_size(); ++i) {
51 int64_t dimension = layout.minor_to_major(i);
52 int64_t size_of_current_dimension = shape.dimensions(dimension);
53
54 // If i is not the last dimension, compute
55 // (linear_index / divisor) % current_dimension.
56 // If i is the last dimension, we can skip the mod, because we assume that
57 // linear is in bounds.
58 //
59 // TODO(jlebar): We could add bounds checks here and elsewhere in this file,
60 // guarded under some sort of xla-memcheck flag. This might be particularly
61 // useful because cuda-memcheck can't help us much in XLA: Most of our
62 // memory lives in one big allocation, so cuda-memcheck can't detect
63 // out-of-bounds accesses.
64 auto* quot = b->CreateUDiv(linear, GetConstantWithIndexType(divisor));
65 if (i < layout.minor_to_major_size() - 1) {
66 (*multidim)[dimension] = b->CreateURem(
67 quot, GetConstantWithIndexType(size_of_current_dimension));
68 } else {
69 (*multidim)[dimension] = quot;
70 }
71 divisor *= size_of_current_dimension;
72 }
73 }
74
Delinearize(std::vector<llvm::Value * > * multidim,llvm::Value * linear,const Shape & shape,absl::Span<llvm::Value * > dynamic_dims,llvm::IRBuilder<> * b) const75 void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
76 llvm::Value* linear, const Shape& shape,
77 absl::Span<llvm::Value*> dynamic_dims,
78 llvm::IRBuilder<>* b) const {
79 CHECK_EQ(shape.dimensions_size(), dynamic_dims.size());
80 CHECK_EQ(multidim_.size(), shape.rank());
81 llvm::Value* divisor = GetConstantWithIndexType(1);
82 const Layout& layout = shape.layout();
83 for (int64_t i = 0; i < layout.minor_to_major_size(); ++i) {
84 int64_t dimension = layout.minor_to_major(i);
85
86 // If i is not the last dimension, compute
87 // (linear_index / divisor) % current_dimension.
88 // If i is the last dimension, we can skip the mod, because we assume that
89 // linear is in bounds.
90 auto* quot = b->CreateUDiv(linear, divisor, "quot");
91 if (i < layout.minor_to_major_size() - 1) {
92 (*multidim)[dimension] =
93 b->CreateURem(quot, dynamic_dims[dimension], "dim_value");
94 divisor = b->CreateMul(divisor, dynamic_dims[dimension], "divisor");
95 } else {
96 (*multidim)[dimension] = quot;
97 }
98 }
99 }
100
// Constructs an index for `shape` from a linear index only: the
// multidimensional components are recomputed from `linear` by emitting
// div/mod instructions (see Delinearize).
IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
                      llvm::IRBuilder<>* b)
    : multidim_(shape.rank()),
      linear_(linear),
      layout_(shape.layout()),
      dims_(shape.dimensions().begin(), shape.dimensions().end()) {
  CHECK_NE(linear, nullptr);
  // The index type is inferred from the linear value's LLVM type.
  index_type_ = linear->getType();
  CHECK(LayoutUtil::HasLayout(shape))
      << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
      << " should have a layout.";
  Delinearize(&multidim_, linear, shape, b);
}
114
// Constructs an index from a linear value plus a partial multidimensional
// index: null entries of `multidim` are filled in by delinearizing `linear`,
// while non-null entries override the delinearized component.
IrArray::Index::Index(llvm::Value* linear,
                      absl::Span<llvm::Value* const> multidim,
                      const Shape& shape, llvm::IRBuilder<>* b)
    : multidim_(shape.rank()),
      linear_(linear),
      layout_(shape.layout()),
      dims_(shape.dimensions().begin(), shape.dimensions().end()) {
  CHECK_NE(linear, nullptr);
  index_type_ = linear->getType();
  CHECK_EQ(multidim.size(), shape.rank());
  // Every caller-supplied component must use the same index type as `linear`.
  for (auto dim : multidim) {
    if (dim) {
      CHECK_EQ(dim->getType(), index_type_);
    }
  }
  CHECK(LayoutUtil::HasLayout(shape))
      << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
      << " should have a layout.";
  Delinearize(&multidim_, linear, shape, b);
  // Prefer explicitly supplied components over the delinearized ones.
  for (int i = 0; i < multidim.size(); ++i) {
    if (multidim[i] != nullptr) {
      multidim_[i] = multidim[i];
    }
  }
}
140
// Constructs an index for a shape whose dimension sizes are only known at
// runtime: `dynamic_dims` holds one llvm::Value per logical dimension. The
// multidimensional components are recomputed from `linear` with div/mod
// instructions against those runtime sizes (see the dynamic Delinearize).
IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
                      absl::Span<llvm::Value*> dynamic_dims,
                      llvm::IRBuilder<>* b)
    : multidim_(shape.rank()),
      linear_(linear),
      layout_(shape.layout()),
      dims_(shape.dimensions().begin(), shape.dimensions().end()) {
  CHECK_NE(linear, nullptr);
  index_type_ = linear->getType();
  CHECK(LayoutUtil::HasLayout(shape))
      << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
      << " should have a layout.";
  Delinearize(&multidim_, linear, shape, dynamic_dims, b);
}
155
// Convenience constructor taking raw dimension sizes instead of a Shape.
// The element type of the synthesized shape does not affect indexing, so
// PRED is used arbitrarily; MakeShape gives it the default layout.
IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
                      absl::Span<int64 const> dimensions,
                      llvm::Type* index_type)
    : Index(multidim, ShapeUtil::MakeShape(/*arbitrary*/ PRED, dimensions),
            index_type) {}
161
// Constructs a purely multidimensional index (no linear form) for `shape`.
// Every component must be non-null; `index_type` is the integer type used
// for index arithmetic derived from this index.
IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
                      const Shape& shape, llvm::Type* index_type)
    : multidim_(multidim.begin(), multidim.end()),
      linear_(nullptr),
      layout_(shape.layout()),
      dims_(shape.dimensions().begin(), shape.dimensions().end()),
      index_type_(index_type) {
  CHECK_NE(index_type_, nullptr);
  CHECK_EQ(shape.dimensions_size(), multidim.size());
  for (const auto* dim : multidim) {
    CHECK_NE(dim, nullptr);
  }
  CHECK(LayoutUtil::HasLayout(shape))
      << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
      << " should have a layout.";
}
178
IrArray(llvm::Value * base_ptr,Shape shape)179 IrArray::IrArray(llvm::Value* base_ptr, Shape shape)
180 : base_ptr_(base_ptr), shape_(std::move(shape)) {
181 TF_CHECK_OK(ShapeUtil::ValidateShape(shape));
182 CHECK(base_ptr_->getType()->isPointerTy());
183 int depth = 0;
184 element_type_ =
185 llvm::cast<llvm::PointerType>(base_ptr_->getType())->getElementType();
186 while (llvm::ArrayType* array_type =
187 llvm::dyn_cast<llvm::ArrayType>(element_type_)) {
188 element_type_ = array_type->getElementType();
189 ++depth;
190 }
191
192 if (!shape_.IsArray() || ShapeUtil::IsScalar(shape_)) {
193 DCHECK(depth == 1 || depth == 0) << depth;
194 } else {
195 DCHECK_EQ(depth, shape_.rank()) << shape.ShortDebugString();
196 }
197 }
198
199 // Returns whether the given linear index is valid on the given shape.
LinearValidOnShape(const Shape & a) const200 bool IrArray::Index::LinearValidOnShape(const Shape& a) const {
201 auto b = ShapeUtil::MakeShape(a.element_type(), dims_);
202 *b.mutable_layout() = layout_;
203 return linear_ != nullptr &&
204 ShapeUtil::ElementsIn(a) == ShapeUtil::ElementsIn(b) &&
205 ShapeUtil::ReshapeIsBitcast(a, b);
206 }
207
// Given this index into `output_shape`, computes the corresponding index
// into `input_shape` for a reshape from `input_shape` to `output_shape`.
// The cached linear index is propagated when the reshape is a bitcast.
IrArray::Index IrArray::Index::SourceIndexOfReshape(
    const Shape& output_shape, const Shape& input_shape,
    llvm::IRBuilder<>* builder) const {
  CHECK_EQ(multidim_.size(), output_shape.rank());
  std::vector<llvm::Value*> source_multidim_index(
      input_shape.rank(), llvm::UndefValue::get(index_type_));
  // Fast path: the reshape only inserts and/or deletes size-1 dimensions, so
  // every non-degenerate component maps through unchanged.
  auto trivial_reshape =
      ShapeUtil::InsertedOrDeleted1SizedDimensions(input_shape, output_shape);
  if (std::get<0>(trivial_reshape)) {
    // The 1-sized dimensions which only appear in 'input_shape'.
    auto deleted_dims_indices = std::get<1>(trivial_reshape);
    // The 1-sized dimensions which only appear in 'output_shape'.
    auto inserted_dims_indices = std::get<2>(trivial_reshape);

    // This is a two-way merge of 'deleted_dims_indices' with indexing into
    // 'source_multidim_index', and a two-way merge of 'inserted_dims_indices'
    // with indexing into 'multidim_'. When we find a dimension in
    // 'source_multidim_index' which does not belong to 'deleted_dims_indices',
    // we retrieve the corresponding value from 'multidim_' (skipping any
    // indices that appear in 'inserted_dims_indices').
    for (int64_t i = 0, j = 0, k = 0, l = 0; i < source_multidim_index.size();
         ++i) {
      if (j == deleted_dims_indices.size() || deleted_dims_indices[j] > i) {
        // This is a dimension that was preserved. Take the matching value from
        // multidim_.
        while (l < inserted_dims_indices.size() &&
               inserted_dims_indices[l] == k) {
          // Skip 1-sized dimensions.
          ++k;
          ++l;
        }
        source_multidim_index[i] = multidim_[k];
        ++k;
      } else {
        // This is a 1-sized dimension that only appears in the operand.
        source_multidim_index[i] = GetConstantWithIndexType(0);
        ++j;
      }
    }
  } else {
    // General path: partition both shapes into maximal "common factor" runs
    // of dimensions with equal element counts; indices within a run can be
    // converted independently of the other runs.
    const auto common_factors =
        CommonFactors(AsInt64Slice(input_shape.dimensions()),
                      AsInt64Slice(output_shape.dimensions()));
    // We compute the source indices in each common factor from only the target
    // indices in the same common factor.
    for (ssize_t k = common_factors.size() - 2; k >= 0; --k) {
      absl::Span<int64 const> dimensions =
          AsInt64Slice(output_shape.dimensions())
              .subspan(common_factors[k].second,
                       common_factors[k + 1].second - common_factors[k].second);
      // Collapse the output components of this run into one linear index.
      llvm::Value* logical_linear_index =
          Index(absl::Span<llvm::Value* const>(multidim_).subspan(
                    common_factors[k].second,
                    common_factors[k + 1].second - common_factors[k].second),
                dimensions, index_type_)
              .Linearize(dimensions, builder);
      // Delinearizes logical_linear_index for the source array in row-major
      // collapsed order. The first rank-1 indices are the remainder of the
      // linear index by each dimension size.
      for (int64_t i = common_factors[k + 1].first - 1;
           i >= common_factors[k].first; --i) {
        llvm::Value* divisor =
            GetConstantWithIndexType(input_shape.dimensions(i));
        if (input_shape.dimensions(i) == 1) {
          source_multidim_index[i] = GetConstantWithIndexType(0);
        } else if (i == common_factors[k].first) {
          // Most-major dimension of the run: no mod needed (in bounds).
          source_multidim_index[i] = logical_linear_index;
        } else {
          source_multidim_index[i] =
              builder->CreateURem(logical_linear_index, divisor);
        }
        logical_linear_index =
            builder->CreateUDiv(logical_linear_index, divisor);
      }
    }
  }

  // Reuse the linear form when the reshape does not move data in memory.
  if (linear() != nullptr && LayoutUtil::HasLayout(input_shape) &&
      LayoutUtil::HasLayout(output_shape) &&
      ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) {
    return Index(source_multidim_index, linear(), input_shape, index_type_);
  }
  return Index(source_multidim_index, input_shape, index_type_);
}
292
SourceIndexOfSlice(const Shape & operand_shape,absl::Span<const int64> starts,absl::Span<const int64> strides,llvm::IRBuilder<> * builder) const293 IrArray::Index IrArray::Index::SourceIndexOfSlice(
294 const Shape& operand_shape, absl::Span<const int64> starts,
295 absl::Span<const int64> strides, llvm::IRBuilder<>* builder) const {
296 std::vector<llvm::Value*> source_multi_index(multidim_.size());
297 for (int i = 0; i < multidim_.size(); ++i) {
298 int64_t stride = strides[i];
299 if (stride != 1) {
300 source_multi_index[i] = builder->CreateAdd(
301 builder->CreateMul(multidim_[i], GetConstantWithIndexType(stride)),
302 GetConstantWithIndexType(starts[i]));
303 } else {
304 source_multi_index[i] =
305 builder->CreateAdd(multidim_[i], GetConstantWithIndexType(starts[i]));
306 }
307 }
308 return Index(source_multi_index, operand_shape, index_type_);
309 }
310
// Given this index into the transpose output `shape`, returns the index into
// `operand_shape`, where `dimension_mapping` is the transpose's permutation.
// The cached linear index is propagated when the transpose is a bitcast
// (i.e. it does not move data in memory).
IrArray::Index IrArray::Index::SourceIndexOfTranspose(
    const Shape& shape, const Shape& operand_shape,
    absl::Span<const int64> dimension_mapping) const {
  // Apply the inverse permutation to recover the operand's components.
  std::vector<llvm::Value*> operand_multidim_index =
      PermuteInverse(multidim(), dimension_mapping);

  if (linear() != nullptr && LayoutUtil::HasLayout(operand_shape) &&
      LayoutUtil::HasLayout(shape) &&
      ShapeUtil::TransposeIsBitcast(operand_shape, shape, dimension_mapping)) {
    return Index(operand_multidim_index, linear(), operand_shape, index_type_);
  }

  return Index(operand_multidim_index, operand_shape, index_type_);
}
325
// Given this index into the bitcast output `shape`, returns the index into
// `operand_shape`. Because a bitcast preserves the physical element order,
// the source index can always be derived from a physical linear index.
IrArray::Index IrArray::Index::SourceIndexOfBitcast(
    const Shape& shape, const Shape& operand_shape,
    llvm::IRBuilder<>* builder) const {
  CHECK(LayoutUtil::HasLayout(shape) && LayoutUtil::HasLayout(operand_shape));

  // In case the bitcast is just a reshape, we can use SourceIndexOfReshape()
  // instead. This will reuse linear() if possible, so we don't have to build a
  // new 'linear_index'.
  if (ShapeUtil::ReshapeIsBitcast(operand_shape, shape)) {
    return SourceIndexOfReshape(shape, operand_shape, builder);
  }

  // If we have a linear index, we can definitely use it because we know the
  // operation is a bitcast. This will recompute the multi-dimensional index for
  // the operand based on the linear index.
  if (linear() != nullptr) {
    return Index(linear(), operand_shape, builder);
  }

  // First linearize the index coming from the output of the bitcast. We want
  // the physical index of the element in the buffer. This is like Linearize,
  // but takes the layout into account.
  int64_t scale = 1;
  llvm::Value* linear_index = GetConstantWithIndexType(0);
  for (auto dimension : LayoutUtil::MinorToMajor(shape)) {
    // linear += multidim[dim] * scale, with no-overflow flags since indices
    // are assumed in bounds.
    linear_index = builder->CreateAdd(
        linear_index,
        builder->CreateMul(multidim_[dimension],
                           GetConstantWithIndexType(scale), "",
                           /*HasNUW=*/true, /*HasNSW=*/true),
        "", /*HasNUW=*/true, /*HasNSW=*/true);
    scale *= shape.dimensions(dimension);
  }

  // Delinearize the physical index against the operand's shape/layout.
  return Index(linear_index, operand_shape, builder);
}
362
// Given this index into the broadcast output `shape`, returns the index into
// `operand_shape`, where `dimension_mapping[i]` is the output dimension that
// operand dimension i was mapped to. When the mapped dimensions are
// physically contiguous and order-preserving, a linear index is also derived
// from linear_ by dividing/modding out the non-broadcast dimensions.
IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
    const Shape& shape, const Shape& operand_shape,
    absl::Span<const int64> dimension_mapping,
    llvm::IRBuilder<>* builder) const {
  int64_t rank = operand_shape.rank();
  std::vector<llvm::Value*> source_index(rank);
  for (int64_t i = 0; i < rank; ++i) {
    source_index[i] = multidim_[dimension_mapping[i]];
  }
  // Without layouts (or for rank 1) we cannot reason about physical
  // contiguity, so return the multidimensional index only.
  if (linear_ == nullptr || !LayoutUtil::HasLayout(operand_shape) ||
      !LayoutUtil::HasLayout(shape) || rank == 1) {
    return Index(source_index, operand_shape, index_type_);
  }
  // High-level idea: we can reuse the linear index if the broadcasted
  // dimensions are contiguous, and this part of the operation is a bitcast.
  // The other dimensions can be masked out with a div and a mod operation.
  std::vector<int64> logical_to_physical =
      LayoutUtil::MakeLogicalToPhysical(shape.layout());
  int64_t output_rank = shape.rank();
  // The minimum physical dimension that is broadcasted.
  int64_t min_broadcasted_dimension = output_rank;
  // The maximum physical dimension that is broadcasted.
  int64_t max_broadcasted_dimension = -1;
  for (int64_t i = 0; i < rank; ++i) {
    int64_t physical_dim = logical_to_physical[dimension_mapping[i]];
    min_broadcasted_dimension =
        std::min(min_broadcasted_dimension, physical_dim);
    max_broadcasted_dimension =
        std::max(max_broadcasted_dimension, physical_dim);
  }
  bool contiguous_broadcast_dimensions =
      max_broadcasted_dimension - min_broadcasted_dimension == rank - 1;
  if (!contiguous_broadcast_dimensions) {
    return Index(source_index, operand_shape, index_type_);
  }
  // Check if the mapped dimensions are a bitcast.
  std::vector<int64> operand_logical_to_physical =
      LayoutUtil::MakeLogicalToPhysical(operand_shape.layout());
  for (int64_t i = 0; i < rank; ++i) {
    if (operand_logical_to_physical[i] !=
        logical_to_physical[dimension_mapping[i]] - min_broadcasted_dimension) {
      return Index(source_index, operand_shape, index_type_);
    }
  }
  llvm::Value* linear = linear_;
  // Divide out the dimensions physically minor to the broadcast block...
  int64_t divisor = 1;
  for (int64_t i = max_broadcasted_dimension + 1; i < output_rank; ++i) {
    divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
  }
  if (divisor > 1) {
    linear = builder->CreateUDiv(linear, GetConstantWithIndexType(divisor));
  }
  // ...and mod out the dimensions physically major to it.
  if (min_broadcasted_dimension > 0) {
    int64_t mod = 1;
    for (int64_t i = min_broadcasted_dimension; i <= max_broadcasted_dimension;
         ++i) {
      mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
    }
    linear = builder->CreateURem(linear, GetConstantWithIndexType(mod));
  }
  return Index(source_index, linear, operand_shape, index_type_);
}
425
Linearize(absl::Span<const int64> dimensions,llvm::IRBuilder<> * builder) const426 llvm::Value* IrArray::Index::Linearize(absl::Span<const int64> dimensions,
427 llvm::IRBuilder<>* builder) const {
428 // Each dimension is multiplied by the product of the sizes of all
429 // earlier dimensions and added to the accumulator logical_linear_index.
430 CHECK_EQ(size(), dimensions.size());
431 llvm::Value* logical_linear_index = GetConstantWithIndexType(0);
432 int64_t multiplier = 1;
433 for (ssize_t i = size() - 1; i >= 0; --i) {
434 llvm::Value* addend =
435 builder->CreateMul((*this)[i], GetConstantWithIndexType(multiplier), "",
436 /*HasNUW=*/true, /*HasNSW=*/true);
437 addend = builder->CreateZExtOrTrunc(addend, index_type_);
438 logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "",
439 /*HasNUW=*/true, /*HasNSW=*/true);
440 multiplier *= dimensions[i];
441 }
442 return logical_linear_index;
443 }
444
Linearize(const std::vector<llvm::Value * > & dynamic_dims,llvm::IRBuilder<> * builder) const445 llvm::Value* IrArray::Index::Linearize(
446 const std::vector<llvm::Value*>& dynamic_dims,
447 llvm::IRBuilder<>* builder) const {
448 // Each dimension is multiplied by the product of the sizes of all
449 // earlier dimensions and added to the accumulator logical_linear_index.
450 CHECK_EQ(size(), dynamic_dims.size());
451 llvm::Value* logical_linear_index = GetConstantWithIndexType(0);
452 llvm::Value* multiplier = GetConstantWithIndexType(1);
453 for (ssize_t i = size() - 1; i >= 0; --i) {
454 llvm::Value* addend = builder->CreateMul((*this)[i], multiplier, "",
455 /*HasNUW=*/true, /*HasNSW=*/true);
456 addend = builder->CreateZExtOrTrunc(addend, index_type_);
457 logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "",
458 /*HasNUW=*/true, /*HasNSW=*/true);
459 if (i) {
460 multiplier = builder->CreateMul(multiplier, dynamic_dims[i],
461 /*Name=*/"multiplier");
462 }
463 }
464 return logical_linear_index;
465 }
466
// Returns the address of the element at `index`. When `use_linear_index` is
// set and the cached linear index is valid for this array's shape, a single
// GEP off a flattened pointer is emitted; otherwise a multidimensional GEP
// is built from the per-dimension indices in major-to-minor order.
llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
                                              llvm::IRBuilder<>* b,
                                              absl::string_view name,
                                              bool use_linear_index) const {
  if (ShapeUtil::IsScalar(shape_)) {
    // Special handling of scalars: a scalar pretends to have the same value for
    // every index, thus effectively implementing broadcasting of its value
    // over higher-rank arrays.
    return base_ptr_;
  }
  CHECK_EQ(index.size(), shape_.rank());
  CHECK(index.ShapeIsCompatible(shape_));

  if (use_linear_index && index.LinearValidOnShape(shape_)) {
    // Flatten: bitcast the base pointer to element_type* and GEP with the
    // single linear index.
    llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
    return b->CreateInBoundsGEP(
        b->CreateBitCast(base_ptr_,
                         PrimitiveTypeToIrType(shape_.element_type(), module)
                             ->getPointerTo()),
        {index.linear()}, llvm_ir::AsStringRef(name));
  }

  std::vector<llvm::Value*> actual_index;
  for (int64_t i = 0; i < index.size(); ++i) {
    // When dimension i is of size 1, LLVM optimization is able to replace
    // index[i] with 0. However, setting index[i] to 0 here still allows LLVM to
    // produce better code in some cases.
    auto dim = shape_.dimensions(i);
    actual_index.push_back(
        dim == 1 ? llvm::ConstantInt::get(index[i]->getType(), 0) : index[i]);
  }

  // "base_ptr_" has the type of "<ir_type_for_its_shape>*"
  // (e.g. [3 x [2 x float]]*). Therefore, the address of the indexed element
  // should be computed by
  //
  //   getelementptr base_ptr_, 0, most major index, ..., most minor index
  CHECK_GT(index.size(), 0);
  // Leading 0 dereferences the pointer itself before indexing dimensions.
  std::vector<llvm::Value*> gep_indices(
      1, llvm::ConstantInt::get(index[0]->getType(), 0));
  for (int64_t i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) {
    int64_t dimension = LayoutUtil::Major(shape_.layout(), i);
    gep_indices.push_back(actual_index[dimension]);
  }
  return b->CreateInBoundsGEP(base_ptr_, gep_indices,
                              llvm_ir::AsStringRef(name));
}
514
AnnotateLoadStoreInstructionWithMetadata(llvm::Instruction * instruction) const515 void IrArray::AnnotateLoadStoreInstructionWithMetadata(
516 llvm::Instruction* instruction) const {
517 CHECK(llvm::isa<llvm::LoadInst>(instruction) ||
518 llvm::isa<llvm::StoreInst>(instruction));
519 CHECK(!llvm::isa<llvm::StoreInst>(instruction) || !is_invariant_)
520 << "Trying to create a store to an invariant IRArray.";
521
522 for (const auto& kind_md_pair : metadata_) {
523 instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
524 }
525 }
526
EmitReadArrayElement(const Index & index,llvm::IRBuilder<> * b,absl::string_view name,bool use_linear_index) const527 llvm::Value* IrArray::EmitReadArrayElement(const Index& index,
528 llvm::IRBuilder<>* b,
529 absl::string_view name,
530 bool use_linear_index) const {
531 llvm::Value* element_address =
532 EmitArrayElementAddress(index, b, name, use_linear_index);
533 llvm::LoadInst* load = b->CreateLoad(element_address, name.data());
534 AnnotateLoadStoreInstructionWithMetadata(load);
535 return load;
536 }
537
EmitWriteArrayElement(const Index & index,llvm::Value * value,llvm::IRBuilder<> * b,bool use_linear_index) const538 void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value,
539 llvm::IRBuilder<>* b,
540 bool use_linear_index) const {
541 llvm::Value* element_address =
542 EmitArrayElementAddress(index, b, "", use_linear_index);
543 llvm::StoreInst* store = b->CreateStore(value, element_address);
544 AnnotateLoadStoreInstructionWithMetadata(store);
545 }
546
CastToShape(const Shape & new_shape,llvm::IRBuilder<> * b) const547 IrArray IrArray::CastToShape(const Shape& new_shape,
548 llvm::IRBuilder<>* b) const {
549 llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
550 llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module);
551 IrArray new_irarray(
552 b->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()), new_shape);
553 new_irarray.metadata_ = metadata_;
554 return new_irarray;
555 }
556
ShapeIsCompatible(const Shape & a,const Shape & b)557 bool IrArray::Index::ShapeIsCompatible(const Shape& a, const Shape& b) {
558 // Compute strides for two sides of the comparison. Sometimes different shapes
559 // give the same strides:
560 // [10, 20, 30, 1]{3,2,1,0} vs [10, 20, 1, 30]{3,2,1,0}
561 // which should be considered compatible.
562 const auto get_strides = [](const Shape& shape) {
563 int rank = shape.dimensions().size();
564 int64_t stride = 1;
565 std::vector<int64> strides;
566 for (int i = 0; i < rank; i++) {
567 auto dim = shape.dimensions(shape.layout().minor_to_major(i));
568 if (dim != 1) {
569 stride *= dim;
570 strides.push_back(stride);
571 }
572 }
573 return strides;
574 };
575
576 return get_strides(a) == get_strides(b);
577 }
578
579 } // namespace llvm_ir
580 } // namespace xla
581