/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"

#include <tuple>
#include <vector>

#include "llvm/IR/Constants.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/permutation_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace llvm_ir {

IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
                      llvm::Value* linear, const Shape& shape,
                      llvm::Type* index_type)
    : Index(multidim, shape, index_type) {
  CHECK_NE(linear, nullptr);
  linear_ = linear;
}

void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
                                 llvm::Value* linear, const Shape& shape,
                                 llvm::IRBuilder<>* b) const {
  int64 divisor = 1;
  const Layout& layout = shape.layout();
  for (int64 i = 0; i < layout.minor_to_major_size(); ++i) {
    int64 dimension = layout.minor_to_major(i);
    int64 size_of_current_dimension = shape.dimensions(dimension);

    // If i is not the last dimension, compute
    //   (linear_index / divisor) % current_dimension.
    // If i is the last dimension, we can skip the mod, because we assume that
    // linear is in bounds.
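    //
    // For example (illustrative), with shape [2,3], layout {1,0} (dimension 1
    // minor-most) and linear index 5, the first iteration recovers
    // multidim[1] = 5 % 3 = 2 and the second recovers multidim[0] = 5 / 3 = 1,
    // matching linear = 1 * 3 + 2 = 5.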
    //
    // TODO(jlebar): We could add bounds checks here and elsewhere in this
    // file, guarded under some sort of xla-memcheck flag. This might be
    // particularly useful because cuda-memcheck can't help us much in XLA:
    // Most of our memory lives in one big allocation, so cuda-memcheck can't
    // detect out-of-bounds accesses.
    auto* quot = b->CreateUDiv(linear, GetConstantWithIndexType(divisor));
    if (i < layout.minor_to_major_size() - 1) {
      (*multidim)[dimension] = b->CreateURem(
          quot, GetConstantWithIndexType(size_of_current_dimension));
    } else {
      (*multidim)[dimension] = quot;
    }
    divisor *= size_of_current_dimension;
  }
}

void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
                                 llvm::Value* linear, const Shape& shape,
                                 absl::Span<llvm::Value*> dynamic_dims,
                                 llvm::IRBuilder<>* b) const {
  CHECK_EQ(shape.dimensions_size(), dynamic_dims.size());
  CHECK_EQ(multidim_.size(), shape.rank());
  llvm::Value* divisor = GetConstantWithIndexType(1);
  const Layout& layout = shape.layout();
  for (int64 i = 0; i < layout.minor_to_major_size(); ++i) {
    int64 dimension = layout.minor_to_major(i);

    // If i is not the last dimension, compute
    //   (linear_index / divisor) % current_dimension.
    // If i is the last dimension, we can skip the mod, because we assume that
    // linear is in bounds.
    auto* quot = b->CreateUDiv(linear, divisor, "quot");
    if (i < layout.minor_to_major_size() - 1) {
      (*multidim)[dimension] =
          b->CreateURem(quot, dynamic_dims[dimension], "dim_value");
      divisor = b->CreateMul(divisor, dynamic_dims[dimension], "divisor");
    } else {
      (*multidim)[dimension] = quot;
    }
  }
}

IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
                      llvm::IRBuilder<>* b)
    : multidim_(shape.rank()),
      linear_(linear),
      layout_(shape.layout()),
      dims_(shape.dimensions().begin(), shape.dimensions().end()) {
  CHECK_NE(linear, nullptr);
  index_type_ = linear->getType();
  CHECK(LayoutUtil::HasLayout(shape))
      << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
      << " should have a layout.";
  Delinearize(&multidim_, linear, shape, b);
}

IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
                      absl::Span<llvm::Value*> dynamic_dims,
                      llvm::IRBuilder<>* b)
    : multidim_(shape.rank()),
      linear_(linear),
      layout_(shape.layout()),
      dims_(shape.dimensions().begin(), shape.dimensions().end()) {
  CHECK_NE(linear, nullptr);
  index_type_ = linear->getType();
  CHECK(LayoutUtil::HasLayout(shape))
      << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
      << " should have a layout.";
  Delinearize(&multidim_, linear, shape, dynamic_dims, b);
}

IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
                      absl::Span<int64 const> dimensions,
                      llvm::Type* index_type)
    : Index(multidim, ShapeUtil::MakeShape(/*arbitrary*/ PRED, dimensions),
            index_type) {}

IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
                      const Shape& shape, llvm::Type* index_type)
    : multidim_(multidim.begin(), multidim.end()),
      linear_(nullptr),
      layout_(shape.layout()),
      dims_(shape.dimensions().begin(), shape.dimensions().end()),
      index_type_(index_type) {
  CHECK_NE(index_type_, nullptr);
  CHECK_EQ(shape.dimensions_size(), multidim.size());
  for (const auto* dim : multidim) {
    CHECK_NE(dim, nullptr);
  }
  CHECK(LayoutUtil::HasLayout(shape))
      << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
      << " should have a layout.";
}

IrArray::IrArray(llvm::Value* base_ptr, Shape shape)
    : base_ptr_(base_ptr), shape_(std::move(shape)) {
  // `shape` has been moved into `shape_`, so only the member is used below.
  TF_CHECK_OK(ShapeUtil::ValidateShape(shape_));
  CHECK(base_ptr_->getType()->isPointerTy());
  int depth = 0;
  element_type_ =
      llvm::cast<llvm::PointerType>(base_ptr_->getType())->getElementType();
  while (llvm::ArrayType* array_type =
             llvm::dyn_cast<llvm::ArrayType>(element_type_)) {
    element_type_ = array_type->getElementType();
    ++depth;
  }

  if (!shape_.IsArray() || ShapeUtil::IsScalar(shape_)) {
    DCHECK(depth == 1 || depth == 0) << depth;
  } else {
    DCHECK_EQ(depth, shape_.rank()) << shape_.ShortDebugString();
  }
}

// Returns whether this index's linear value can be used to address shape `a`
// directly: the index must carry a linear value, `a` must have the same number
// of elements as this index's shape, and reinterpreting `a` as this index's
// shape (with its layout) must be a bitcast.
bool IrArray::Index::LinearValidOnShape(const Shape& a) const {
  auto b = ShapeUtil::MakeShape(a.element_type(), dims_);
  *b.mutable_layout() = layout_;
  return linear_ != nullptr &&
         ShapeUtil::ElementsIn(a) == ShapeUtil::ElementsIn(b) &&
         ShapeUtil::ReshapeIsBitcast(a, b);
}

IrArray::Index IrArray::Index::SourceIndexOfReshape(
    const Shape& output_shape, const Shape& input_shape,
    llvm::IRBuilder<>* builder) const {
  CHECK_EQ(multidim_.size(), output_shape.rank());
  std::vector<llvm::Value*> source_multidim_index(
      input_shape.rank(), llvm::UndefValue::get(index_type_));
  auto trivial_reshape =
      ShapeUtil::InsertedOrDeleted1SizedDimensions(input_shape, output_shape);
  if (std::get<0>(trivial_reshape)) {
    // The 1-sized dimensions which only appear in 'input_shape'.
    auto deleted_dims_indices = std::get<1>(trivial_reshape);
    // The 1-sized dimensions which only appear in 'output_shape'.
    auto inserted_dims_indices = std::get<2>(trivial_reshape);

    // This is a two-way merge of 'deleted_dims_indices' with indexing into
    // 'source_multidim_index', and a two-way merge of 'inserted_dims_indices'
    // with indexing into 'multidim_'. When we find a dimension in
    // 'source_multidim_index' which does not belong to 'deleted_dims_indices',
    // we retrieve the corresponding value from 'multidim_' (skipping any
    // indices that appear in 'inserted_dims_indices').
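    //
    // For example (illustrative), reshaping [2,1,3] to [1,2,3] deletes the
    // 1-sized dimension 1 of the input and inserts the 1-sized dimension 0 of
    // the output, so an output index {0, i, j} maps to the source index
    // {i, 0, j}.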
    for (int64 i = 0, j = 0, k = 0, l = 0; i < source_multidim_index.size();
         ++i) {
      if (j == deleted_dims_indices.size() || deleted_dims_indices[j] > i) {
        // This is a dimension that was preserved. Take the matching value from
        // multidim_.
        while (l < inserted_dims_indices.size() &&
               inserted_dims_indices[l] == k) {
          // Skip 1-sized dimensions.
          ++k;
          ++l;
        }
        source_multidim_index[i] = multidim_[k];
        ++k;
      } else {
        // This is a 1-sized dimension that only appears in the operand.
        source_multidim_index[i] = GetConstantWithIndexType(0);
        ++j;
      }
    }
  } else {
    const auto common_factors =
        CommonFactors(AsInt64Slice(input_shape.dimensions()),
                      AsInt64Slice(output_shape.dimensions()));
    // We compute the source indices in each common factor from only the
    // target indices in the same common factor.
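    //
    // For example (illustrative), for a reshape from [6] to [2,3] the two
    // shapes form a single common factor, so the output index {i, j} is first
    // linearized to i * 3 + j and then used directly as the rank-1 source
    // index.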
    for (ssize_t k = common_factors.size() - 2; k >= 0; --k) {
      absl::Span<int64 const> dimensions =
          AsInt64Slice(output_shape.dimensions())
              .subspan(common_factors[k].second,
                       common_factors[k + 1].second - common_factors[k].second);
      llvm::Value* logical_linear_index =
          Index(absl::Span<llvm::Value* const>(multidim_).subspan(
                    common_factors[k].second,
                    common_factors[k + 1].second - common_factors[k].second),
                dimensions, index_type_)
              .Linearize(dimensions, builder);
      // Delinearizes logical_linear_index for the source array in row-major
      // collapsed order. The first rank-1 indices are the remainder of the
      // linear index by each dimension size.
      for (int64 i = common_factors[k + 1].first - 1;
           i >= common_factors[k].first; --i) {
        llvm::Value* divisor =
            GetConstantWithIndexType(input_shape.dimensions(i));
        if (input_shape.dimensions(i) == 1) {
          source_multidim_index[i] = GetConstantWithIndexType(0);
        } else if (i == common_factors[k].first) {
          source_multidim_index[i] = logical_linear_index;
        } else {
          source_multidim_index[i] =
              builder->CreateURem(logical_linear_index, divisor);
        }
        logical_linear_index =
            builder->CreateUDiv(logical_linear_index, divisor);
      }
    }
  }

  if (linear() != nullptr && LayoutUtil::HasLayout(input_shape) &&
      LayoutUtil::HasLayout(output_shape) &&
      ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) {
    return Index(source_multidim_index, linear(), input_shape, index_type_);
  }
  return Index(source_multidim_index, input_shape, index_type_);
}

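// Maps an index into the output of a slice back into the slice operand:
// each source coordinate is index[i] * strides[i] + starts[i].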
IrArray::Index IrArray::Index::SourceIndexOfSlice(
    const Shape& operand_shape, absl::Span<const int64> starts,
    absl::Span<const int64> strides, llvm::IRBuilder<>* builder) const {
  std::vector<llvm::Value*> source_multi_index(multidim_.size());
  for (int i = 0; i < multidim_.size(); ++i) {
    int64 stride = strides[i];
    if (stride != 1) {
      source_multi_index[i] = builder->CreateAdd(
          builder->CreateMul(multidim_[i], GetConstantWithIndexType(stride)),
          GetConstantWithIndexType(starts[i]));
    } else {
      source_multi_index[i] =
          builder->CreateAdd(multidim_[i], GetConstantWithIndexType(starts[i]));
    }
  }
  return Index(source_multi_index, operand_shape, index_type_);
}

IrArray::Index IrArray::Index::SourceIndexOfTranspose(
    const Shape& shape, const Shape& operand_shape,
    absl::Span<const int64> dimension_mapping) const {
  std::vector<llvm::Value*> operand_multidim_index =
      PermuteInverse(multidim(), dimension_mapping);

  if (linear() != nullptr && LayoutUtil::HasLayout(operand_shape) &&
      LayoutUtil::HasLayout(shape) &&
      ShapeUtil::TransposeIsBitcast(operand_shape, shape, dimension_mapping)) {
    return Index(operand_multidim_index, linear(), operand_shape, index_type_);
  }

  return Index(operand_multidim_index, operand_shape, index_type_);
}

IrArray::Index IrArray::Index::SourceIndexOfBitcast(
    const Shape& shape, const Shape& operand_shape,
    llvm::IRBuilder<>* builder) const {
  CHECK(LayoutUtil::HasLayout(shape) && LayoutUtil::HasLayout(operand_shape));

  // In case the bitcast is just a reshape, we can use SourceIndexOfReshape()
  // instead. This will reuse linear() if possible, so we don't have to build a
  // new 'linear_index'.
  if (ShapeUtil::ReshapeIsBitcast(operand_shape, shape)) {
    return SourceIndexOfReshape(shape, operand_shape, builder);
  }

  // If we have a linear index, we can definitely use it because we know the
  // operation is a bitcast. This will recompute the multi-dimensional index
  // for the operand based on the linear index.
  if (linear() != nullptr) {
    return Index(linear(), operand_shape, builder);
  }

  // First linearize the index coming from the output of the bitcast. We want
  // the physical index of the element in the buffer. This is like Linearize,
  // but takes the layout into account.
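  //
  // For example (illustrative), for an output shape [2,3] with layout {0,1}
  // (dimension 0 minor-most) and index {i0, i1}, the physical linear index
  // computed below is i0 + i1 * 2.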
  int64 scale = 1;
  llvm::Value* linear_index = GetConstantWithIndexType(0);
  for (auto dimension : LayoutUtil::MinorToMajor(shape)) {
    linear_index = builder->CreateAdd(
        linear_index,
        builder->CreateMul(multidim_[dimension],
                           GetConstantWithIndexType(scale), "",
                           /*HasNUW=*/true, /*HasNSW=*/true),
        "", /*HasNUW=*/true, /*HasNSW=*/true);
    scale *= shape.dimensions(dimension);
  }

  return Index(linear_index, operand_shape, builder);
}

IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
    const Shape& shape, const Shape& operand_shape,
    absl::Span<const int64> dimension_mapping,
    llvm::IRBuilder<>* builder) const {
  int64 rank = operand_shape.rank();
  std::vector<llvm::Value*> source_index(rank);
  for (int64 i = 0; i < rank; ++i) {
    source_index[i] = multidim_[dimension_mapping[i]];
  }
  if (linear_ == nullptr || !LayoutUtil::HasLayout(operand_shape) ||
      !LayoutUtil::HasLayout(shape)) {
    return Index(source_index, operand_shape, index_type_);
  }
  // High-level idea: we can reuse the linear index if the broadcasted
  // dimensions are contiguous, and this part of the operation is a bitcast.
  // The other dimensions can be masked out with a div and a mod operation.
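  //
  // For example (illustrative), broadcasting an operand [M] into [N, M] along
  // dimension 1, with both shapes laid out row-major, satisfies these
  // conditions, and the operand's linear index is recovered below as
  // linear_ % M (the division by `divisor` is skipped since divisor == 1).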
  std::vector<int64> logical_to_physical =
      LayoutUtil::MakeLogicalToPhysical(shape.layout());
  int64 output_rank = shape.rank();
  // The minimum physical dimension that is broadcasted.
  int64 min_broadcasted_dimension = output_rank;
  // The maximum physical dimension that is broadcasted.
  int64 max_broadcasted_dimension = -1;
  for (int64 i = 0; i < rank; ++i) {
    int64 physical_dim = logical_to_physical[dimension_mapping[i]];
    min_broadcasted_dimension =
        std::min(min_broadcasted_dimension, physical_dim);
    max_broadcasted_dimension =
        std::max(max_broadcasted_dimension, physical_dim);
  }
  bool contiguous_broadcast_dimensions =
      max_broadcasted_dimension - min_broadcasted_dimension == rank - 1;
  if (!contiguous_broadcast_dimensions) {
    return Index(source_index, operand_shape, index_type_);
  }
  // Check if the mapped dimensions are a bitcast.
  std::vector<int64> operand_logical_to_physical =
      LayoutUtil::MakeLogicalToPhysical(operand_shape.layout());
  for (int64 i = 0; i < rank; ++i) {
    if (operand_logical_to_physical[i] !=
        logical_to_physical[dimension_mapping[i]] - min_broadcasted_dimension) {
      return Index(source_index, operand_shape, index_type_);
    }
  }
  llvm::Value* linear = linear_;
  int64 divisor = 1;
  for (int64 i = max_broadcasted_dimension + 1; i < output_rank; ++i) {
    divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
  }
  if (divisor > 1) {
    linear = builder->CreateUDiv(linear, GetConstantWithIndexType(divisor));
  }
  if (min_broadcasted_dimension > 0) {
    int64 mod = 1;
    for (int64 i = min_broadcasted_dimension; i <= max_broadcasted_dimension;
         ++i) {
      mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
    }
    linear = builder->CreateURem(linear, GetConstantWithIndexType(mod));
  }
  return Index(source_index, linear, operand_shape, index_type_);
}

llvm::Value* IrArray::Index::Linearize(absl::Span<const int64> dimensions,
                                       llvm::IRBuilder<>* builder) const {
  // Row-major linearization: each dimension's index is multiplied by the
  // product of the sizes of all the dimensions that follow it and added to
  // the accumulator logical_linear_index.
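  //
  // For example (illustrative), for dimensions [2, 3] and index {1, 2} this
  // produces 1 * 3 + 2 = 5.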
  CHECK_EQ(size(), dimensions.size());
  llvm::Value* logical_linear_index = GetConstantWithIndexType(0);
  int64 multiplier = 1;
  for (ssize_t i = size() - 1; i >= 0; --i) {
    llvm::Value* addend =
        builder->CreateMul((*this)[i], GetConstantWithIndexType(multiplier), "",
                           /*HasNUW=*/true, /*HasNSW=*/true);
    addend = builder->CreateZExtOrTrunc(addend, index_type_);
    logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "",
                                              /*HasNUW=*/true, /*HasNSW=*/true);
    multiplier *= dimensions[i];
  }
  return logical_linear_index;
}

llvm::Value* IrArray::Index::Linearize(
    const std::vector<llvm::Value*>& dynamic_dims,
    llvm::IRBuilder<>* builder) const {
  // Row-major linearization: each dimension's index is multiplied by the
  // product of the sizes of all the dimensions that follow it and added to
  // the accumulator logical_linear_index.
  CHECK_EQ(size(), dynamic_dims.size());
  llvm::Value* logical_linear_index = GetConstantWithIndexType(0);
  llvm::Value* multiplier = GetConstantWithIndexType(1);
  for (ssize_t i = size() - 1; i >= 0; --i) {
    llvm::Value* addend = builder->CreateMul((*this)[i], multiplier, "",
                                             /*HasNUW=*/true, /*HasNSW=*/true);
    addend = builder->CreateZExtOrTrunc(addend, index_type_);
    logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "",
                                              /*HasNUW=*/true, /*HasNSW=*/true);
    if (i) {
      multiplier = builder->CreateMul(multiplier, dynamic_dims[i],
                                      /*Name=*/"multiplier");
    }
  }
  return logical_linear_index;
}

llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
                                              llvm::IRBuilder<>* b,
                                              absl::string_view name,
                                              bool use_linear_index) const {
  if (ShapeUtil::IsScalar(shape_)) {
    // Special handling of scalars: a scalar pretends to have the same value
    // for every index, thus effectively implementing broadcasting of its
    // value over higher-rank arrays.
    return base_ptr_;
  }
  CHECK_EQ(index.size(), shape_.rank());
  CHECK(index.ShapeIsCompatible(shape_));

  if (use_linear_index && index.LinearValidOnShape(shape_)) {
    llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
    return b->CreateInBoundsGEP(
        b->CreateBitCast(base_ptr_,
                         PrimitiveTypeToIrType(shape_.element_type(), module)
                             ->getPointerTo()),
        {index.linear()}, llvm_ir::AsStringRef(name));
  }

  std::vector<llvm::Value*> actual_index;
  for (int64 i = 0; i < index.size(); ++i) {
    // When dimension i is of size 1, LLVM optimization is able to replace
    // index[i] with 0. However, setting index[i] to 0 here still allows LLVM
    // to produce better code in some cases.
    auto dim = shape_.dimensions(i);
    actual_index.push_back(
        dim == 1 ? llvm::ConstantInt::get(index[i]->getType(), 0) : index[i]);
  }

  // "base_ptr_" has the type of "<ir_type_for_its_shape>*"
  // (e.g. [3 x [2 x float]]*). Therefore, the address of the indexed element
  // should be computed by
  //
  //   getelementptr base_ptr_, 0, most major index, ..., most minor index
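  //
  // For example (illustrative), for a shape [3, 2] with a row-major layout
  // ({1,0}) and index {i, j}, the GEP indices built below are {0, i, j}.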
  CHECK_GT(index.size(), 0);
  std::vector<llvm::Value*> gep_indices(
      1, llvm::ConstantInt::get(index[0]->getType(), 0));
  for (int64 i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) {
    int64 dimension = LayoutUtil::Major(shape_.layout(), i);
    gep_indices.push_back(actual_index[dimension]);
  }
  return b->CreateInBoundsGEP(base_ptr_, gep_indices,
                              llvm_ir::AsStringRef(name));
}

void IrArray::AnnotateLoadStoreInstructionWithMetadata(
    llvm::Instruction* instruction) const {
  CHECK(llvm::isa<llvm::LoadInst>(instruction) ||
        llvm::isa<llvm::StoreInst>(instruction));
  CHECK(!llvm::isa<llvm::StoreInst>(instruction) || !is_invariant_)
      << "Trying to create a store to an invariant IRArray.";

  for (const auto& kind_md_pair : metadata_) {
    instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
  }
}

llvm::Value* IrArray::EmitReadArrayElement(const Index& index,
                                           llvm::IRBuilder<>* b,
                                           absl::string_view name,
                                           bool use_linear_index) const {
  llvm::Value* element_address =
      EmitArrayElementAddress(index, b, name, use_linear_index);
  // Use AsStringRef rather than name.data(): a string_view is not guaranteed
  // to be null-terminated.
  llvm::LoadInst* load =
      b->CreateLoad(element_address, llvm_ir::AsStringRef(name));
  AnnotateLoadStoreInstructionWithMetadata(load);
  return load;
}

void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value,
                                    llvm::IRBuilder<>* b,
                                    bool use_linear_index) const {
  llvm::Value* element_address =
      EmitArrayElementAddress(index, b, "", use_linear_index);
  llvm::StoreInst* store = b->CreateStore(value, element_address);
  AnnotateLoadStoreInstructionWithMetadata(store);
}

IrArray IrArray::CastToShape(const Shape& new_shape,
                             llvm::IRBuilder<>* b) const {
  llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
  llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module);
  IrArray new_irarray(
      b->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()), new_shape);
  new_irarray.metadata_ = metadata_;
  return new_irarray;
}

bool IrArray::Index::ShapeIsCompatible(const Shape& a, const Shape& b) {
  // Compute strides for two sides of the comparison. Sometimes different
  // shapes give the same strides:
  //   [10, 20, 30, 1]{3,2,1,0} vs [10, 20, 1, 30]{3,2,1,0}
  // which should be considered compatible.
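  // (Both of those shapes yield strides {30, 600, 6000} with the computation
  // below, since 1-sized dimensions are skipped.)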
  const auto get_strides = [](const Shape& shape) {
    int rank = shape.dimensions().size();
    int64 stride = 1;
    std::vector<int64> strides;
    for (int i = 0; i < rank; i++) {
      auto dim = shape.dimensions(shape.layout().minor_to_major(i));
      if (dim != 1) {
        stride *= dim;
        strides.push_back(stride);
      }
    }
    return strides;
  };

  return get_strides(a) == get_strides(b);
}

}  // namespace llvm_ir
}  // namespace xla