/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"

#include <algorithm>
#include <optional>
#include <tuple>
#include <utility>
#include <vector>
#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_type_conversion_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace llvm_ir {

IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
                      llvm::Value* linear, const Shape& shape,
                      llvm::Type* index_type)
    : Index(multidim, shape, index_type) {
  CHECK_NE(linear, nullptr);
  linear_ = linear;
}

void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
                                 llvm::Value* linear, const Shape& shape,
                                 llvm::IRBuilder<>* b) const {
  int64_t divisor = 1;
  const Layout& layout = shape.layout();
  for (int64_t i = 0; i < layout.minor_to_major_size(); ++i) {
    int64_t dimension = layout.minor_to_major(i);
    int64_t size_of_current_dimension = shape.dimensions(dimension);

    // If i is not the last dimension, compute
    //   (linear_index / divisor) % current_dimension.
    // If i is the last dimension, we can skip the mod, because we assume that
    // linear is in bounds.
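    //
    // For example, with dimensions [2, 3] and layout {1, 0} (dimension 1
    // minor-most), linear index 4 delinearizes to (1, 1): dimension 1 is
    // 4 % 3 == 1, then dimension 0 is 4 / 3 == 1, with no mod needed for the
    // major-most dimension.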
    //
    // TODO(jlebar): We could add bounds checks here and elsewhere in this file,
    // guarded under some sort of xla-memcheck flag. This might be particularly
    // useful because cuda-memcheck can't help us much in XLA: Most of our
    // memory lives in one big allocation, so cuda-memcheck can't detect
    // out-of-bounds accesses.
    auto* quot = b->CreateUDiv(linear, GetConstantWithIndexType(divisor));
    if (i < layout.minor_to_major_size() - 1) {
      (*multidim)[dimension] = b->CreateURem(
          quot, GetConstantWithIndexType(size_of_current_dimension));
    } else {
      (*multidim)[dimension] = quot;
    }
    divisor *= size_of_current_dimension;
  }
}

void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
                                 llvm::Value* linear, const Shape& shape,
                                 absl::Span<llvm::Value*> dynamic_dims,
                                 llvm::IRBuilder<>* b) const {
  CHECK_EQ(shape.dimensions_size(), dynamic_dims.size());
  CHECK_EQ(multidim_.size(), shape.rank());
  llvm::Value* divisor = GetConstantWithIndexType(1);
  const Layout& layout = shape.layout();
  for (int64_t i = 0; i < layout.minor_to_major_size(); ++i) {
    int64_t dimension = layout.minor_to_major(i);

    // If i is not the last dimension, compute
    //   (linear_index / divisor) % current_dimension.
    // If i is the last dimension, we can skip the mod, because we assume that
    // linear is in bounds.
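    //
    // Unlike the static-shape overload above, the dimension sizes here are
    // runtime values, so the divisor is accumulated as an llvm::Value rather
    // than a compile-time constant.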
    auto* quot = b->CreateUDiv(linear, divisor, "quot");
    if (i < layout.minor_to_major_size() - 1) {
      llvm::Value* casted_dynamic_dim =
          b->CreateIntCast(dynamic_dims[dimension], quot->getType(),
                           /*isSigned=*/true);
      (*multidim)[dimension] =
          b->CreateURem(quot, casted_dynamic_dim, "dim_value");
      divisor = b->CreateMul(divisor, casted_dynamic_dim, "divisor");
    } else {
      (*multidim)[dimension] = quot;
    }
  }
}

IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
                      llvm::IRBuilder<>* b)
    : multidim_(shape.rank()),
      linear_(linear),
      layout_(shape.layout()),
      dims_(shape.dimensions().begin(), shape.dimensions().end()) {
  CHECK_NE(linear, nullptr);
  index_type_ = linear->getType();
  CHECK(LayoutUtil::HasLayout(shape))
      << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
      << " should have a layout.";
  Delinearize(&multidim_, linear, shape, b);
}

IrArray::Index::Index(llvm::Value* linear,
                      absl::Span<llvm::Value* const> multidim,
                      const Shape& shape, llvm::IRBuilder<>* b)
    : multidim_(shape.rank()),
      linear_(linear),
      layout_(shape.layout()),
      dims_(shape.dimensions().begin(), shape.dimensions().end()) {
  CHECK_NE(linear, nullptr);
  index_type_ = linear->getType();
  CHECK_EQ(multidim.size(), shape.rank());
  for (auto dim : multidim) {
    if (dim) {
      CHECK_EQ(dim->getType(), index_type_);
    }
  }
  CHECK(LayoutUtil::HasLayout(shape))
      << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
      << " should have a layout.";
  Delinearize(&multidim_, linear, shape, b);
  for (int i = 0; i < multidim.size(); ++i) {
    if (multidim[i] != nullptr) {
      multidim_[i] = multidim[i];
    }
  }
}

IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
                      absl::Span<llvm::Value*> dynamic_dims,
                      llvm::IRBuilder<>* b)
    : multidim_(shape.rank()),
      linear_(linear),
      layout_(shape.layout()),
      dims_(shape.dimensions().begin(), shape.dimensions().end()) {
  CHECK_NE(linear, nullptr);
  index_type_ = linear->getType();
  CHECK(LayoutUtil::HasLayout(shape))
      << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
      << " should have a layout.";
  Delinearize(&multidim_, linear, shape, dynamic_dims, b);
}

IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
                      absl::Span<int64_t const> dimensions,
                      llvm::Type* index_type)
    : Index(multidim, ShapeUtil::MakeShape(/*arbitrary*/ PRED, dimensions),
            index_type) {}

IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
                      const Shape& shape, llvm::Type* index_type)
    : multidim_(multidim.begin(), multidim.end()),
      linear_(nullptr),
      layout_(shape.layout()),
      dims_(shape.dimensions().begin(), shape.dimensions().end()),
      index_type_(index_type) {
  CHECK_NE(index_type_, nullptr);
  CHECK_EQ(shape.dimensions_size(), multidim.size());
  for (const auto* dim : multidim) {
    CHECK_NE(dim, nullptr);
  }
  CHECK(LayoutUtil::HasLayout(shape))
      << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
      << " should have a layout.";
}

IrArray::IrArray(llvm::Value* base_ptr, llvm::Type* pointee_type, Shape shape)
    : base_ptr_(base_ptr),
      pointee_type_(pointee_type),
      shape_(std::move(shape)) {
  TF_CHECK_OK(ShapeUtil::ValidateShape(shape_));
  CHECK(base_ptr_->getType()->isPointerTy());
  CHECK(llvm::cast<llvm::PointerType>(base_ptr_->getType())
            ->isOpaqueOrPointeeTypeMatches(pointee_type));
  int depth = 0;
  element_type_ = pointee_type;
  while (llvm::ArrayType* array_type =
             llvm::dyn_cast<llvm::ArrayType>(element_type_)) {
    element_type_ = array_type->getElementType();
    ++depth;
  }

  if (!shape_.IsArray() || ShapeUtil::IsScalar(shape_)) {
    DCHECK(depth == 1 || depth == 0) << depth;
  } else {
    DCHECK_EQ(depth, shape_.rank()) << shape_.ShortDebugString();
  }
}

// Returns whether the given linear index is valid on the given shape.
bool IrArray::Index::LinearValidOnShape(const Shape& a) const {
  auto b = ShapeUtil::MakeShape(a.element_type(), dims_);
  *b.mutable_layout() = layout_;
  return linear_ != nullptr &&
         ShapeUtil::ElementsIn(a) == ShapeUtil::ElementsIn(b) &&
         ShapeUtil::ReshapeIsBitcast(a, b);
}

IrArray::Index IrArray::Index::SourceIndexOfReshape(
    const Shape& output_shape, const Shape& input_shape,
    llvm::IRBuilder<>* builder) const {
  CHECK_EQ(multidim_.size(), output_shape.rank());
  std::vector<llvm::Value*> source_multidim_index(
      input_shape.rank(), llvm::UndefValue::get(index_type_));

  if (std::optional<ShapeUtil::ShapeEqualityDescriptor> trivial_reshape =
          ShapeUtil::InsertedOrDeleted1SizedDimensions(input_shape,
                                                       output_shape)) {
    // This is a two-way merge of 'deleted_dims_indices' with indexing into
    // 'source_multidim_index', and a two-way merge of 'inserted_dims_indices'
    // with indexing into 'multidim_'. When we find a dimension in
    // 'source_multidim_index' which does not belong to 'deleted_dims_indices',
    // we retrieve the corresponding value from 'multidim_' (skipping any
    // indices that appear in 'inserted_dims_indices').
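    //
    // For example, reshaping [2,1,3] to [2,3] deletes input dimension 1, so
    // the source index becomes (multidim_[0], 0, multidim_[1]); a deleted
    // size-1 dimension always gets index 0.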
    for (int64_t i = 0, j = 0, k = 0, l = 0; i < source_multidim_index.size();
         ++i) {
      if (j == trivial_reshape->deleted_dimensions.size() ||
          trivial_reshape->deleted_dimensions[j] > i) {
        // This is a dimension that was preserved. Take the matching value from
        // multidim_.
        while (l < trivial_reshape->inserted_dimensions.size() &&
               trivial_reshape->inserted_dimensions[l] == k) {
          // Skip 1-sized dimensions.
          ++k;
          ++l;
        }
        source_multidim_index[i] = multidim_[k];
        ++k;
      } else {
        // This is a 1-sized dimension that only appears in the operand.
        source_multidim_index[i] = GetConstantWithIndexType(0);
        ++j;
      }
    }
  } else {
    const auto common_factors =
        CommonFactors(input_shape.dimensions(), output_shape.dimensions());
    // We compute the source indices in each common factor from only the target
    // indices in the same common factor.
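    //
    // For example, for a reshape from [6,4] to [2,3,4], CommonFactors groups
    // input dimension {6} with output dimensions {2,3} and input dimension {4}
    // with output dimension {4}; each group is linearized and then
    // delinearized independently.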
    for (ssize_t k = common_factors.size() - 2; k >= 0; --k) {
      absl::Span<int64_t const> dimensions = output_shape.dimensions().subspan(
          common_factors[k].second,
          common_factors[k + 1].second - common_factors[k].second);
      llvm::Value* logical_linear_index =
          Index(absl::Span<llvm::Value* const>(multidim_).subspan(
                    common_factors[k].second,
                    common_factors[k + 1].second - common_factors[k].second),
                dimensions, index_type_)
              .Linearize(dimensions, builder);
      // Delinearizes logical_linear_index for the source array in row-major
      // collapsed order. The first rank-1 indices are the remainder of the
      // linear index by each dimension size.
      for (int64_t i = common_factors[k + 1].first - 1;
           i >= common_factors[k].first; --i) {
        llvm::Value* divisor =
            GetConstantWithIndexType(input_shape.dimensions(i));
        if (input_shape.dimensions(i) == 1) {
          source_multidim_index[i] = GetConstantWithIndexType(0);
        } else if (i == common_factors[k].first) {
          source_multidim_index[i] = logical_linear_index;
        } else {
          source_multidim_index[i] =
              builder->CreateURem(logical_linear_index, divisor);
        }
        logical_linear_index =
            builder->CreateUDiv(logical_linear_index, divisor);
      }
    }
  }

  if (linear() != nullptr && LayoutUtil::HasLayout(input_shape) &&
      LayoutUtil::HasLayout(output_shape) &&
      ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) {
    return Index(source_multidim_index, linear(), input_shape, index_type_);
  }
  return Index(source_multidim_index, input_shape, index_type_);
}

IrArray::Index IrArray::Index::SourceIndexOfSlice(
    const Shape& operand_shape, absl::Span<const int64_t> starts,
    absl::Span<const int64_t> strides, llvm::IRBuilder<>* builder) const {
  std::vector<llvm::Value*> source_multi_index(multidim_.size());
  for (int i = 0; i < multidim_.size(); ++i) {
    int64_t stride = strides[i];
    if (stride != 1) {
      source_multi_index[i] = builder->CreateAdd(
          builder->CreateMul(multidim_[i], GetConstantWithIndexType(stride)),
          GetConstantWithIndexType(starts[i]));
    } else {
      source_multi_index[i] =
          builder->CreateAdd(multidim_[i], GetConstantWithIndexType(starts[i]));
    }
  }
  return Index(source_multi_index, operand_shape, index_type_);
}

IrArray::Index IrArray::Index::SourceIndexOfTranspose(
    const Shape& shape, const Shape& operand_shape,
    absl::Span<const int64_t> dimension_mapping) const {
  std::vector<llvm::Value*> operand_multidim_index =
      PermuteInverse(multidim(), dimension_mapping);

  if (linear() != nullptr && LayoutUtil::HasLayout(operand_shape) &&
      LayoutUtil::HasLayout(shape) &&
      ShapeUtil::TransposeIsBitcast(operand_shape, shape, dimension_mapping)) {
    return Index(operand_multidim_index, linear(), operand_shape, index_type_);
  }

  return Index(operand_multidim_index, operand_shape, index_type_);
}

IrArray::Index IrArray::Index::SourceIndexOfBitcast(
    const Shape& shape, const Shape& operand_shape,
    llvm::IRBuilder<>* builder) const {
  CHECK(LayoutUtil::HasLayout(shape) && LayoutUtil::HasLayout(operand_shape));

  // In case the bitcast is just a reshape, we can use SourceIndexOfReshape()
  // instead. This will reuse linear() if possible, so we don't have to build a
  // new 'linear_index'.
  if (ShapeUtil::ReshapeIsBitcast(operand_shape, shape)) {
    return SourceIndexOfReshape(shape, operand_shape, builder);
  }

  // If we have a linear index, we can definitely use it because we know the
  // operation is a bitcast. This will recompute the multi-dimensional index
  // for the operand based on the linear index.
  if (linear() != nullptr) {
    return Index(linear(), operand_shape, builder);
  }

  // First linearize the index coming from the output of the bitcast. We want
  // the physical index of the element in the buffer. This is like Linearize,
  // but takes the layout into account.
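  //
  // For example, for a shape [2,3] with layout {0,1} (dimension 0 minor-most),
  // the physical linear index of element (i0, i1) is i0 + 2 * i1, whereas
  // Linearize() would produce the logical (row-major) value 3 * i0 + i1.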
  int64_t scale = 1;
  llvm::Value* linear_index = GetConstantWithIndexType(0);
  for (auto dimension : LayoutUtil::MinorToMajor(shape)) {
    linear_index = builder->CreateAdd(
        linear_index,
        builder->CreateMul(multidim_[dimension],
                           GetConstantWithIndexType(scale), "",
                           /*HasNUW=*/true, /*HasNSW=*/true),
        "", /*HasNUW=*/true, /*HasNSW=*/true);
    scale *= shape.dimensions(dimension);
  }

  return Index(linear_index, operand_shape, builder);
}

IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
    const Shape& shape, const Shape& operand_shape,
    absl::Span<const int64_t> dimension_mapping,
    llvm::IRBuilder<>* builder) const {
  int64_t rank = operand_shape.rank();
  std::vector<llvm::Value*> source_index(rank);
  for (int64_t i = 0; i < rank; ++i) {
    source_index[i] = multidim_[dimension_mapping[i]];
  }
  if (linear_ == nullptr || !LayoutUtil::HasLayout(operand_shape) ||
      !LayoutUtil::HasLayout(shape) || rank == 1) {
    return Index(source_index, operand_shape, index_type_);
  }
  // High-level idea: we can reuse the linear index if the broadcasted
  // dimensions are contiguous, and this part of the operation is a bitcast.
  // The other dimensions can be masked out with a div and a mod operation.
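  //
  // For example, broadcasting an operand of shape [3,4]{1,0} into an output of
  // shape [2,3,4]{2,1,0} with dimension_mapping {1,2} maps the operand onto
  // the two minor-most physical dimensions of the output, so the operand's
  // linear index is simply linear_ % (3 * 4).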
  std::vector<int64_t> logical_to_physical =
      LayoutUtil::MakeLogicalToPhysical(shape.layout());
  int64_t output_rank = shape.rank();
  // The minimum physical dimension that is broadcasted.
  int64_t min_broadcasted_dimension = output_rank;
  // The maximum physical dimension that is broadcasted.
  int64_t max_broadcasted_dimension = -1;
  for (int64_t i = 0; i < rank; ++i) {
    int64_t physical_dim = logical_to_physical[dimension_mapping[i]];
    min_broadcasted_dimension =
        std::min(min_broadcasted_dimension, physical_dim);
    max_broadcasted_dimension =
        std::max(max_broadcasted_dimension, physical_dim);
  }
  bool contiguous_broadcast_dimensions =
      max_broadcasted_dimension - min_broadcasted_dimension == rank - 1;
  if (!contiguous_broadcast_dimensions) {
    return Index(source_index, operand_shape, index_type_);
  }
  // Check if the mapped dimensions are a bitcast.
  std::vector<int64_t> operand_logical_to_physical =
      LayoutUtil::MakeLogicalToPhysical(operand_shape.layout());
  for (int64_t i = 0; i < rank; ++i) {
    if (operand_logical_to_physical[i] !=
        logical_to_physical[dimension_mapping[i]] - min_broadcasted_dimension) {
      return Index(source_index, operand_shape, index_type_);
    }
  }
  llvm::Value* linear = linear_;
  int64_t divisor = 1;
  for (int64_t i = max_broadcasted_dimension + 1; i < output_rank; ++i) {
    divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
  }
  if (divisor > 1) {
    linear = builder->CreateUDiv(linear, GetConstantWithIndexType(divisor));
  }
  if (min_broadcasted_dimension > 0) {
    int64_t mod = 1;
    for (int64_t i = min_broadcasted_dimension; i <= max_broadcasted_dimension;
         ++i) {
      mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
    }
    linear = builder->CreateURem(linear, GetConstantWithIndexType(mod));
  }
  return Index(source_index, linear, operand_shape, index_type_);
}

llvm::Value* IrArray::Index::Linearize(absl::Span<const int64_t> dimensions,
                                       llvm::IRBuilder<>* builder) const {
  // Each dimension index is multiplied by the product of the sizes of all
  // more-minor (later) dimensions and added to the accumulator
  // logical_linear_index.
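  //
  // For example, with dimensions [2, 3], the index (1, 1) linearizes to
  // 1 * 3 + 1 * 1 == 4, the inverse of the delinearization example above.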
  CHECK_EQ(size(), dimensions.size());
  llvm::Value* logical_linear_index = GetConstantWithIndexType(0);
  int64_t multiplier = 1;
  for (ssize_t i = size() - 1; i >= 0; --i) {
    llvm::Value* addend =
        builder->CreateMul((*this)[i], GetConstantWithIndexType(multiplier), "",
                           /*HasNUW=*/true, /*HasNSW=*/true);
    addend = builder->CreateZExtOrTrunc(addend, index_type_);
    logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "",
                                              /*HasNUW=*/true, /*HasNSW=*/true);
    multiplier *= dimensions[i];
  }
  return logical_linear_index;
}

llvm::Value* IrArray::Index::Linearize(
    const std::vector<llvm::Value*>& dynamic_dims,
    llvm::IRBuilder<>* builder) const {
  // Each dimension index is multiplied by the product of the sizes of all
  // more-minor (later) dimensions and added to the accumulator
  // logical_linear_index.
  CHECK_EQ(size(), dynamic_dims.size());
  llvm::Value* logical_linear_index = GetConstantWithIndexType(0);
  llvm::Value* multiplier = GetConstantWithIndexType(1);
  for (ssize_t i = size() - 1; i >= 0; --i) {
    llvm::Value* addend = builder->CreateMul((*this)[i], multiplier, "",
                                             /*HasNUW=*/true, /*HasNSW=*/true);
    addend = builder->CreateZExtOrTrunc(addend, index_type_);
    logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "",
                                              /*HasNUW=*/true, /*HasNSW=*/true);
    if (i) {
      multiplier = builder->CreateMul(multiplier, dynamic_dims[i],
                                      /*Name=*/"multiplier");
    }
  }
  return logical_linear_index;
}

llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
                                              llvm::IRBuilder<>* b,
                                              absl::string_view name,
                                              bool use_linear_index) const {
  if (ShapeUtil::IsScalar(shape_)) {
    // Special handling of scalars: a scalar pretends to have the same value for
    // every index, thus effectively implementing broadcasting of its value
    // over higher-rank arrays.
    return base_ptr_;
  }
  CHECK_EQ(index.size(), shape_.rank());
  CHECK(index.ShapeIsCompatible(shape_))
      << "Shape " << index.AsShapeWithType(shape_.element_type()).ToString(true)
      << " is not compatible with " << shape_.ToString(true);

  if (use_linear_index && index.LinearValidOnShape(shape_)) {
    llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
    llvm::Type* type = PrimitiveTypeToIrType(shape_.element_type(), module);
    return b->CreateInBoundsGEP(
        type, b->CreateBitCast(base_ptr_, type->getPointerTo()), index.linear(),
        llvm_ir::AsStringRef(name));
  }

  std::vector<llvm::Value*> actual_index;
  for (int64_t i = 0; i < index.size(); ++i) {
    // When dimension i is of size 1, LLVM optimization is able to replace
    // index[i] with 0. However, setting index[i] to 0 here still allows LLVM to
    // produce better code in some cases.
    auto dim = shape_.dimensions(i);
    actual_index.push_back(
        dim == 1 ? llvm::ConstantInt::get(index[i]->getType(), 0) : index[i]);
  }

  // "base_ptr_" has the type of "<ir_type_for_its_shape>*"
  // (e.g. [3 x [2 x float]]*). Therefore, the address of the indexed element
  // should be computed by
  //
  //   getelementptr base_ptr_, 0, most major index, ..., most minor index
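  //
  // For a [3 x [2 x float]] array with the default layout {1,0}, for instance,
  // this emits the equivalent of
  //   getelementptr [3 x [2 x float]], base_ptr_, 0, index[0], index[1]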
  CHECK_GT(index.size(), 0);
  std::vector<llvm::Value*> gep_indices(
      1, llvm::ConstantInt::get(index[0]->getType(), 0));
  for (int64_t i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) {
    int64_t dimension = LayoutUtil::Major(shape_.layout(), i);
    gep_indices.push_back(actual_index[dimension]);
  }
  return b->CreateInBoundsGEP(pointee_type_, base_ptr_, gep_indices,
                              llvm_ir::AsStringRef(name));
}

void IrArray::AnnotateLoadStoreInstructionWithMetadata(
    llvm::Instruction* instruction) const {
  CHECK(llvm::isa<llvm::LoadInst>(instruction) ||
        llvm::isa<llvm::StoreInst>(instruction));
  CHECK(!llvm::isa<llvm::StoreInst>(instruction) || !is_invariant_)
      << "Trying to create a store to an invariant IRArray.";

  for (const auto& kind_md_pair : metadata_) {
    instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
  }
}

llvm::Value* IrArray::EmitReadArrayElement(const Index& index,
                                           llvm::IRBuilder<>* b,
                                           absl::string_view name,
                                           bool use_linear_index) const {
  llvm::Value* element_address =
      EmitArrayElementAddress(index, b, name, use_linear_index);
  llvm::LoadInst* load =
      b->CreateLoad(element_type_, element_address, llvm_ir::AsStringRef(name));
  AnnotateLoadStoreInstructionWithMetadata(load);
  return load;
}

void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value,
                                    llvm::IRBuilder<>* b,
                                    bool use_linear_index) const {
  llvm::Value* element_address =
      EmitArrayElementAddress(index, b, "", use_linear_index);
  llvm::StoreInst* store = b->CreateStore(value, element_address);
  AnnotateLoadStoreInstructionWithMetadata(store);
}

IrArray IrArray::CastToShape(const Shape& new_shape,
                             llvm::IRBuilder<>* b) const {
  llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
  llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module);
  IrArray new_irarray(
      b->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()), new_ir_type,
      new_shape);
  new_irarray.metadata_ = metadata_;
  return new_irarray;
}

bool IrArray::Index::ShapeIsCompatible(const Shape& a, const Shape& b) {
  // Compute strides for two sides of the comparison. Sometimes different shapes
  // give the same strides:
  //   [10, 20, 30, 1]{3,2,1,0} vs [10, 20, 1, 30]{3,2,1,0}
  // which should be considered compatible.
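  //
  // For both example shapes above, degenerate (size-1) dimensions are skipped,
  // and the computed stride fingerprint is {30, 600, 6000}, so they compare
  // equal.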
  const auto get_strides = [](const Shape& shape) {
    int rank = shape.dimensions().size();
    int64_t stride = 1;
    std::vector<int64_t> strides;
    for (int i = 0; i < rank; i++) {
      auto dim = shape.dimensions(shape.layout().minor_to_major(i));
      if (dim != 1) {
        stride *= dim;
        strides.push_back(stride);
      }
    }
    return strides;
  };

  return get_strides(a) == get_strides(b);
}

}  // namespace llvm_ir
}  // namespace xla