• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
17 
18 #include <optional>
19 #include <tuple>
20 #include <utility>
21 #include <vector>
22 
23 #include "llvm/IR/Constants.h"
24 #include "llvm/IR/Instructions.h"
25 #include "llvm/IR/Value.h"
26 #include "tensorflow/compiler/xla/layout_util.h"
27 #include "tensorflow/compiler/xla/permutation_util.h"
28 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_type_conversion_util.h"
29 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
30 #include "tensorflow/compiler/xla/shape_util.h"
31 #include "tensorflow/compiler/xla/statusor.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/platform/logging.h"
35 
36 namespace xla {
37 namespace llvm_ir {
38 
Index(absl::Span<llvm::Value * const> multidim,llvm::Value * linear,const Shape & shape,llvm::Type * index_type)39 IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
40                       llvm::Value* linear, const Shape& shape,
41                       llvm::Type* index_type)
42     : Index(multidim, shape, index_type) {
43   CHECK_NE(linear, nullptr);
44   linear_ = linear;
45 }
46 
Delinearize(std::vector<llvm::Value * > * multidim,llvm::Value * linear,const Shape & shape,llvm::IRBuilder<> * b) const47 void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
48                                  llvm::Value* linear, const Shape& shape,
49                                  llvm::IRBuilder<>* b) const {
50   int64_t divisor = 1;
51   const Layout& layout = shape.layout();
52   for (int64_t i = 0; i < layout.minor_to_major_size(); ++i) {
53     int64_t dimension = layout.minor_to_major(i);
54     int64_t size_of_current_dimension = shape.dimensions(dimension);
55 
56     // If i is not the last dimension, compute
57     //   (linear_index / divisor) % current_dimension.
58     // If i is the last dimension, we can skip the mod, because we assume that
59     // linear is in bounds.
60     //
61     // TODO(jlebar): We could add bounds checks here and elsewhere in this file,
62     // guarded under some sort of xla-memcheck flag.  This might be particularly
63     // useful because cuda-memcheck can't help us much in XLA: Most of our
64     // memory lives in one big allocation, so cuda-memcheck can't detect
65     // out-of-bounds accesses.
66     auto* quot = b->CreateUDiv(linear, GetConstantWithIndexType(divisor));
67     if (i < layout.minor_to_major_size() - 1) {
68       (*multidim)[dimension] = b->CreateURem(
69           quot, GetConstantWithIndexType(size_of_current_dimension));
70     } else {
71       (*multidim)[dimension] = quot;
72     }
73     divisor *= size_of_current_dimension;
74   }
75 }
76 
Delinearize(std::vector<llvm::Value * > * multidim,llvm::Value * linear,const Shape & shape,absl::Span<llvm::Value * > dynamic_dims,llvm::IRBuilder<> * b) const77 void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
78                                  llvm::Value* linear, const Shape& shape,
79                                  absl::Span<llvm::Value*> dynamic_dims,
80                                  llvm::IRBuilder<>* b) const {
81   CHECK_EQ(shape.dimensions_size(), dynamic_dims.size());
82   CHECK_EQ(multidim_.size(), shape.rank());
83   llvm::Value* divisor = GetConstantWithIndexType(1);
84   const Layout& layout = shape.layout();
85   for (int64_t i = 0; i < layout.minor_to_major_size(); ++i) {
86     int64_t dimension = layout.minor_to_major(i);
87 
88     // If i is not the last dimension, compute
89     //   (linear_index / divisor) % current_dimension.
90     // If i is the last dimension, we can skip the mod, because we assume that
91     // linear is in bounds.
92     auto* quot = b->CreateUDiv(linear, divisor, "quot");
93     if (i < layout.minor_to_major_size() - 1) {
94       llvm::Value* casted_dynamic_dim =
95           b->CreateIntCast(dynamic_dims[dimension], quot->getType(),
96                            /*isSigned=*/true);
97       (*multidim)[dimension] =
98           b->CreateURem(quot, casted_dynamic_dim, "dim_value");
99       divisor = b->CreateMul(divisor, casted_dynamic_dim, "divisor");
100     } else {
101       (*multidim)[dimension] = quot;
102     }
103   }
104 }
105 
Index(llvm::Value * linear,const Shape & shape,llvm::IRBuilder<> * b)106 IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
107                       llvm::IRBuilder<>* b)
108     : multidim_(shape.rank()),
109       linear_(linear),
110       layout_(shape.layout()),
111       dims_(shape.dimensions().begin(), shape.dimensions().end()) {
112   CHECK_NE(linear, nullptr);
113   index_type_ = linear->getType();
114   CHECK(LayoutUtil::HasLayout(shape))
115       << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
116       << " should have a layout.";
117   Delinearize(&multidim_, linear, shape, b);
118 }
119 
Index(llvm::Value * linear,absl::Span<llvm::Value * const> multidim,const Shape & shape,llvm::IRBuilder<> * b)120 IrArray::Index::Index(llvm::Value* linear,
121                       absl::Span<llvm::Value* const> multidim,
122                       const Shape& shape, llvm::IRBuilder<>* b)
123     : multidim_(shape.rank()),
124       linear_(linear),
125       layout_(shape.layout()),
126       dims_(shape.dimensions().begin(), shape.dimensions().end()) {
127   CHECK_NE(linear, nullptr);
128   index_type_ = linear->getType();
129   CHECK_EQ(multidim.size(), shape.rank());
130   for (auto dim : multidim) {
131     if (dim) {
132       CHECK_EQ(dim->getType(), index_type_);
133     }
134   }
135   CHECK(LayoutUtil::HasLayout(shape))
136       << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
137       << " should have a layout.";
138   Delinearize(&multidim_, linear, shape, b);
139   for (int i = 0; i < multidim.size(); ++i) {
140     if (multidim[i] != nullptr) {
141       multidim_[i] = multidim[i];
142     }
143   }
144 }
145 
Index(llvm::Value * linear,const Shape & shape,absl::Span<llvm::Value * > dynamic_dims,llvm::IRBuilder<> * b)146 IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
147                       absl::Span<llvm::Value*> dynamic_dims,
148                       llvm::IRBuilder<>* b)
149     : multidim_(shape.rank()),
150       linear_(linear),
151       layout_(shape.layout()),
152       dims_(shape.dimensions().begin(), shape.dimensions().end()) {
153   CHECK_NE(linear, nullptr);
154   index_type_ = linear->getType();
155   CHECK(LayoutUtil::HasLayout(shape))
156       << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
157       << " should have a layout.";
158   Delinearize(&multidim_, linear, shape, dynamic_dims, b);
159 }
160 
Index(absl::Span<llvm::Value * const> multidim,absl::Span<int64_t const> dimensions,llvm::Type * index_type)161 IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
162                       absl::Span<int64_t const> dimensions,
163                       llvm::Type* index_type)
164     : Index(multidim, ShapeUtil::MakeShape(/*arbitrary*/ PRED, dimensions),
165             index_type) {}
166 
Index(absl::Span<llvm::Value * const> multidim,const Shape & shape,llvm::Type * index_type)167 IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
168                       const Shape& shape, llvm::Type* index_type)
169     : multidim_(multidim.begin(), multidim.end()),
170       linear_(nullptr),
171       layout_(shape.layout()),
172       dims_(shape.dimensions().begin(), shape.dimensions().end()),
173       index_type_(index_type) {
174   CHECK_NE(index_type_, nullptr);
175   CHECK_EQ(shape.dimensions_size(), multidim.size());
176   for (const auto* dim : multidim) {
177     CHECK_NE(dim, nullptr);
178   }
179   CHECK(LayoutUtil::HasLayout(shape))
180       << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
181       << " should have a layout.";
182 }
183 
IrArray(llvm::Value * base_ptr,llvm::Type * pointee_type,Shape shape)184 IrArray::IrArray(llvm::Value* base_ptr, llvm::Type* pointee_type, Shape shape)
185     : base_ptr_(base_ptr),
186       pointee_type_(pointee_type),
187       shape_(std::move(shape)) {
188   TF_CHECK_OK(ShapeUtil::ValidateShape(shape));
189   CHECK(base_ptr_->getType()->isPointerTy());
190   CHECK(llvm::cast<llvm::PointerType>(base_ptr_->getType())
191             ->isOpaqueOrPointeeTypeMatches(pointee_type));
192   int depth = 0;
193   element_type_ = pointee_type;
194   while (llvm::ArrayType* array_type =
195              llvm::dyn_cast<llvm::ArrayType>(element_type_)) {
196     element_type_ = array_type->getElementType();
197     ++depth;
198   }
199 
200   if (!shape_.IsArray() || ShapeUtil::IsScalar(shape_)) {
201     DCHECK(depth == 1 || depth == 0) << depth;
202   } else {
203     DCHECK_EQ(depth, shape_.rank()) << shape.ShortDebugString();
204   }
205 }
206 
207 // Returns whether the given linear index is valid on the given shape.
LinearValidOnShape(const Shape & a) const208 bool IrArray::Index::LinearValidOnShape(const Shape& a) const {
209   auto b = ShapeUtil::MakeShape(a.element_type(), dims_);
210   *b.mutable_layout() = layout_;
211   return linear_ != nullptr &&
212          ShapeUtil::ElementsIn(a) == ShapeUtil::ElementsIn(b) &&
213          ShapeUtil::ReshapeIsBitcast(a, b);
214 }
215 
SourceIndexOfReshape(const Shape & output_shape,const Shape & input_shape,llvm::IRBuilder<> * builder) const216 IrArray::Index IrArray::Index::SourceIndexOfReshape(
217     const Shape& output_shape, const Shape& input_shape,
218     llvm::IRBuilder<>* builder) const {
219   CHECK_EQ(multidim_.size(), output_shape.rank());
220   std::vector<llvm::Value*> source_multidim_index(
221       input_shape.rank(), llvm::UndefValue::get(index_type_));
222 
223   if (std::optional<ShapeUtil::ShapeEqualityDescriptor> trivial_reshape =
224           ShapeUtil::InsertedOrDeleted1SizedDimensions(input_shape,
225                                                        output_shape)) {
226     // This is a two-way merge of 'deleted_dims_indices' with indexing into
227     // 'source_multidim_index', and a two-way merge of 'inserted_dims_indices'
228     // with indexing into 'multidim_'. When we find a dimension in
229     // 'source_multidim_index' which does not belong to 'deleted_dims_indices',
230     // we retrieve the corresponding value from 'multidim_' (skipping any
231     // indices that appear in 'inserted_dims_indices').
232     for (int64_t i = 0, j = 0, k = 0, l = 0; i < source_multidim_index.size();
233          ++i) {
234       if (j == trivial_reshape->deleted_dimensions.size() ||
235           trivial_reshape->deleted_dimensions[j] > i) {
236         // This is a dimension that was preserved. Take the matching value from
237         // multidim_.
238         while (l < trivial_reshape->inserted_dimensions.size() &&
239                trivial_reshape->inserted_dimensions[l] == k) {
240           // Skip 1-sized dimensions.
241           ++k;
242           ++l;
243         }
244         source_multidim_index[i] = multidim_[k];
245         ++k;
246       } else {
247         // This is a 1-sized dimension that only appears in the operand.
248         source_multidim_index[i] = GetConstantWithIndexType(0);
249         ++j;
250       }
251     }
252   } else {
253     const auto common_factors =
254         CommonFactors(input_shape.dimensions(), output_shape.dimensions());
255     // We compute the source indices in each common factor from only the target
256     // indices in the same common factor.
257     for (ssize_t k = common_factors.size() - 2; k >= 0; --k) {
258       absl::Span<int64_t const> dimensions = output_shape.dimensions().subspan(
259           common_factors[k].second,
260           common_factors[k + 1].second - common_factors[k].second);
261       llvm::Value* logical_linear_index =
262           Index(absl::Span<llvm::Value* const>(multidim_).subspan(
263                     common_factors[k].second,
264                     common_factors[k + 1].second - common_factors[k].second),
265                 dimensions, index_type_)
266               .Linearize(dimensions, builder);
267       // Delinearizes logical_linear_index for the source array in row-major
268       // collapsed order. The first rank-1 indices are the remainder of the
269       // linear index by each dimension size.
270       for (int64_t i = common_factors[k + 1].first - 1;
271            i >= common_factors[k].first; --i) {
272         llvm::Value* divisor =
273             GetConstantWithIndexType(input_shape.dimensions(i));
274         if (input_shape.dimensions(i) == 1) {
275           source_multidim_index[i] = GetConstantWithIndexType(0);
276         } else if (i == common_factors[k].first) {
277           source_multidim_index[i] = logical_linear_index;
278         } else {
279           source_multidim_index[i] =
280               builder->CreateURem(logical_linear_index, divisor);
281         }
282         logical_linear_index =
283             builder->CreateUDiv(logical_linear_index, divisor);
284       }
285     }
286   }
287 
288   if (linear() != nullptr && LayoutUtil::HasLayout(input_shape) &&
289       LayoutUtil::HasLayout(output_shape) &&
290       ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) {
291     return Index(source_multidim_index, linear(), input_shape, index_type_);
292   }
293   return Index(source_multidim_index, input_shape, index_type_);
294 }
295 
SourceIndexOfSlice(const Shape & operand_shape,absl::Span<const int64_t> starts,absl::Span<const int64_t> strides,llvm::IRBuilder<> * builder) const296 IrArray::Index IrArray::Index::SourceIndexOfSlice(
297     const Shape& operand_shape, absl::Span<const int64_t> starts,
298     absl::Span<const int64_t> strides, llvm::IRBuilder<>* builder) const {
299   std::vector<llvm::Value*> source_multi_index(multidim_.size());
300   for (int i = 0; i < multidim_.size(); ++i) {
301     int64_t stride = strides[i];
302     if (stride != 1) {
303       source_multi_index[i] = builder->CreateAdd(
304           builder->CreateMul(multidim_[i], GetConstantWithIndexType(stride)),
305           GetConstantWithIndexType(starts[i]));
306     } else {
307       source_multi_index[i] =
308           builder->CreateAdd(multidim_[i], GetConstantWithIndexType(starts[i]));
309     }
310   }
311   return Index(source_multi_index, operand_shape, index_type_);
312 }
313 
SourceIndexOfTranspose(const Shape & shape,const Shape & operand_shape,absl::Span<const int64_t> dimension_mapping) const314 IrArray::Index IrArray::Index::SourceIndexOfTranspose(
315     const Shape& shape, const Shape& operand_shape,
316     absl::Span<const int64_t> dimension_mapping) const {
317   std::vector<llvm::Value*> operand_multidim_index =
318       PermuteInverse(multidim(), dimension_mapping);
319 
320   if (linear() != nullptr && LayoutUtil::HasLayout(operand_shape) &&
321       LayoutUtil::HasLayout(shape) &&
322       ShapeUtil::TransposeIsBitcast(operand_shape, shape, dimension_mapping)) {
323     return Index(operand_multidim_index, linear(), operand_shape, index_type_);
324   }
325 
326   return Index(operand_multidim_index, operand_shape, index_type_);
327 }
328 
SourceIndexOfBitcast(const Shape & shape,const Shape & operand_shape,llvm::IRBuilder<> * builder) const329 IrArray::Index IrArray::Index::SourceIndexOfBitcast(
330     const Shape& shape, const Shape& operand_shape,
331     llvm::IRBuilder<>* builder) const {
332   CHECK(LayoutUtil::HasLayout(shape) && LayoutUtil::HasLayout(operand_shape));
333 
334   // In case the bitcast is just a reshape, we can use SourceIndexOfReshape()
335   // instead. This will reuse linear() if possible, so we don't have to build a
336   // new 'linear_index'.
337   if (ShapeUtil::ReshapeIsBitcast(operand_shape, shape)) {
338     return SourceIndexOfReshape(shape, operand_shape, builder);
339   }
340 
341   // If we have a linear index, we can definitely use it because we know the
342   // operation is a bitcast. This will recompute the multi-dimensional index for
343   // the operand based on the linear index.
344   if (linear() != nullptr) {
345     return Index(linear(), operand_shape, builder);
346   }
347 
348   // First linearize the index coming from the output of the bitcast. We want
349   // the physical index of the element in the buffer. This is like Linearize,
350   // but takes the layout into account.
351   int64_t scale = 1;
352   llvm::Value* linear_index = GetConstantWithIndexType(0);
353   for (auto dimension : LayoutUtil::MinorToMajor(shape)) {
354     linear_index = builder->CreateAdd(
355         linear_index,
356         builder->CreateMul(multidim_[dimension],
357                            GetConstantWithIndexType(scale), "",
358                            /*HasNUW=*/true, /*HasNSW=*/true),
359         "", /*HasNUW=*/true, /*HasNSW=*/true);
360     scale *= shape.dimensions(dimension);
361   }
362 
363   return Index(linear_index, operand_shape, builder);
364 }
365 
SourceIndexOfBroadcast(const Shape & shape,const Shape & operand_shape,absl::Span<const int64_t> dimension_mapping,llvm::IRBuilder<> * builder) const366 IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
367     const Shape& shape, const Shape& operand_shape,
368     absl::Span<const int64_t> dimension_mapping,
369     llvm::IRBuilder<>* builder) const {
370   int64_t rank = operand_shape.rank();
371   std::vector<llvm::Value*> source_index(rank);
372   for (int64_t i = 0; i < rank; ++i) {
373     source_index[i] = multidim_[dimension_mapping[i]];
374   }
375   if (linear_ == nullptr || !LayoutUtil::HasLayout(operand_shape) ||
376       !LayoutUtil::HasLayout(shape) || rank == 1) {
377     return Index(source_index, operand_shape, index_type_);
378   }
379   // High-level idea: we can reuse the linear index if the broadcasted
380   // dimensions are contiguous, and this part of the operation is a bitcast.
381   // The other dimensions can be masked out with a div and a mod operation.
382   std::vector<int64_t> logical_to_physical =
383       LayoutUtil::MakeLogicalToPhysical(shape.layout());
384   int64_t output_rank = shape.rank();
385   // The minimum physical dimension that is broadcasted.
386   int64_t min_broadcasted_dimension = output_rank;
387   // The maximum physical dimension that is broadcasted.
388   int64_t max_broadcasted_dimension = -1;
389   for (int64_t i = 0; i < rank; ++i) {
390     int64_t physical_dim = logical_to_physical[dimension_mapping[i]];
391     min_broadcasted_dimension =
392         std::min(min_broadcasted_dimension, physical_dim);
393     max_broadcasted_dimension =
394         std::max(max_broadcasted_dimension, physical_dim);
395   }
396   bool contiguous_broadcast_dimensions =
397       max_broadcasted_dimension - min_broadcasted_dimension == rank - 1;
398   if (!contiguous_broadcast_dimensions) {
399     return Index(source_index, operand_shape, index_type_);
400   }
401   // Check if the mapped dimensions are a bitcast.
402   std::vector<int64_t> operand_logical_to_physical =
403       LayoutUtil::MakeLogicalToPhysical(operand_shape.layout());
404   for (int64_t i = 0; i < rank; ++i) {
405     if (operand_logical_to_physical[i] !=
406         logical_to_physical[dimension_mapping[i]] - min_broadcasted_dimension) {
407       return Index(source_index, operand_shape, index_type_);
408     }
409   }
410   llvm::Value* linear = linear_;
411   int64_t divisor = 1;
412   for (int64_t i = max_broadcasted_dimension + 1; i < output_rank; ++i) {
413     divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
414   }
415   if (divisor > 1) {
416     linear = builder->CreateUDiv(linear, GetConstantWithIndexType(divisor));
417   }
418   if (min_broadcasted_dimension > 0) {
419     int64_t mod = 1;
420     for (int64_t i = min_broadcasted_dimension; i <= max_broadcasted_dimension;
421          ++i) {
422       mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
423     }
424     linear = builder->CreateURem(linear, GetConstantWithIndexType(mod));
425   }
426   return Index(source_index, linear, operand_shape, index_type_);
427 }
428 
Linearize(absl::Span<const int64_t> dimensions,llvm::IRBuilder<> * builder) const429 llvm::Value* IrArray::Index::Linearize(absl::Span<const int64_t> dimensions,
430                                        llvm::IRBuilder<>* builder) const {
431   // Each dimension is multiplied by the product of the sizes of all
432   // earlier dimensions and added to the accumulator logical_linear_index.
433   CHECK_EQ(size(), dimensions.size());
434   llvm::Value* logical_linear_index = GetConstantWithIndexType(0);
435   int64_t multiplier = 1;
436   for (ssize_t i = size() - 1; i >= 0; --i) {
437     llvm::Value* addend =
438         builder->CreateMul((*this)[i], GetConstantWithIndexType(multiplier), "",
439                            /*HasNUW=*/true, /*HasNSW=*/true);
440     addend = builder->CreateZExtOrTrunc(addend, index_type_);
441     logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "",
442                                               /*HasNUW=*/true, /*HasNSW=*/true);
443     multiplier *= dimensions[i];
444   }
445   return logical_linear_index;
446 }
447 
Linearize(const std::vector<llvm::Value * > & dynamic_dims,llvm::IRBuilder<> * builder) const448 llvm::Value* IrArray::Index::Linearize(
449     const std::vector<llvm::Value*>& dynamic_dims,
450     llvm::IRBuilder<>* builder) const {
451   // Each dimension is multiplied by the product of the sizes of all
452   // earlier dimensions and added to the accumulator logical_linear_index.
453   CHECK_EQ(size(), dynamic_dims.size());
454   llvm::Value* logical_linear_index = GetConstantWithIndexType(0);
455   llvm::Value* multiplier = GetConstantWithIndexType(1);
456   for (ssize_t i = size() - 1; i >= 0; --i) {
457     llvm::Value* addend = builder->CreateMul((*this)[i], multiplier, "",
458                                              /*HasNUW=*/true, /*HasNSW=*/true);
459     addend = builder->CreateZExtOrTrunc(addend, index_type_);
460     logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "",
461                                               /*HasNUW=*/true, /*HasNSW=*/true);
462     if (i) {
463       multiplier = builder->CreateMul(multiplier, dynamic_dims[i],
464                                       /*Name=*/"multiplier");
465     }
466   }
467   return logical_linear_index;
468 }
469 
EmitArrayElementAddress(const IrArray::Index & index,llvm::IRBuilder<> * b,absl::string_view name,bool use_linear_index) const470 llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
471                                               llvm::IRBuilder<>* b,
472                                               absl::string_view name,
473                                               bool use_linear_index) const {
474   if (ShapeUtil::IsScalar(shape_)) {
475     // Special handling of scalars: a scalar pretends to have the same value for
476     // every index, thus effectively implementing broadcasting of its value
477     // over higher-rank arrays.
478     return base_ptr_;
479   }
480   CHECK_EQ(index.size(), shape_.rank());
481   CHECK(index.ShapeIsCompatible(shape_))
482       << "Shape " << index.AsShapeWithType(shape_.element_type()).ToString(true)
483       << " is not compatible with " << shape_.ToString(true);
484 
485   if (use_linear_index && index.LinearValidOnShape(shape_)) {
486     llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
487     llvm::Type* type = PrimitiveTypeToIrType(shape_.element_type(), module);
488     return b->CreateInBoundsGEP(
489         type, b->CreateBitCast(base_ptr_, type->getPointerTo()), index.linear(),
490         llvm_ir::AsStringRef(name));
491   }
492 
493   std::vector<llvm::Value*> actual_index;
494   for (int64_t i = 0; i < index.size(); ++i) {
495     // When dimension i is of size 1, LLVM optimization is able to replace
496     // index[i] with 0. However, setting index[i] to 0 here still allows LLVM to
497     // produce better code in some cases.
498     auto dim = shape_.dimensions(i);
499     actual_index.push_back(
500         dim == 1 ? llvm::ConstantInt::get(index[i]->getType(), 0) : index[i]);
501   }
502 
503   // "base_ptr_" has the type of "<ir_type_for_its_shape>*"
504   // (e.g. [3 x [2 x float]]*). Therefore, the address of the indexed element
505   // should be computed by
506   //
507   //   getelementptr base_ptr_, 0, most major index, ..., most minor index
508   CHECK_GT(index.size(), 0);
509   std::vector<llvm::Value*> gep_indices(
510       1, llvm::ConstantInt::get(index[0]->getType(), 0));
511   for (int64_t i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) {
512     int64_t dimension = LayoutUtil::Major(shape_.layout(), i);
513     gep_indices.push_back(actual_index[dimension]);
514   }
515   return b->CreateInBoundsGEP(pointee_type_, base_ptr_, gep_indices,
516                               llvm_ir::AsStringRef(name));
517 }
518 
AnnotateLoadStoreInstructionWithMetadata(llvm::Instruction * instruction) const519 void IrArray::AnnotateLoadStoreInstructionWithMetadata(
520     llvm::Instruction* instruction) const {
521   CHECK(llvm::isa<llvm::LoadInst>(instruction) ||
522         llvm::isa<llvm::StoreInst>(instruction));
523   CHECK(!llvm::isa<llvm::StoreInst>(instruction) || !is_invariant_)
524       << "Trying to create a store to an invariant IRArray.";
525 
526   for (const auto& kind_md_pair : metadata_) {
527     instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
528   }
529 }
530 
EmitReadArrayElement(const Index & index,llvm::IRBuilder<> * b,absl::string_view name,bool use_linear_index) const531 llvm::Value* IrArray::EmitReadArrayElement(const Index& index,
532                                            llvm::IRBuilder<>* b,
533                                            absl::string_view name,
534                                            bool use_linear_index) const {
535   llvm::Value* element_address =
536       EmitArrayElementAddress(index, b, name, use_linear_index);
537   llvm::LoadInst* load =
538       b->CreateLoad(element_type_, element_address, llvm_ir::AsStringRef(name));
539   AnnotateLoadStoreInstructionWithMetadata(load);
540   return load;
541 }
542 
EmitWriteArrayElement(const Index & index,llvm::Value * value,llvm::IRBuilder<> * b,bool use_linear_index) const543 void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value,
544                                     llvm::IRBuilder<>* b,
545                                     bool use_linear_index) const {
546   llvm::Value* element_address =
547       EmitArrayElementAddress(index, b, "", use_linear_index);
548   llvm::StoreInst* store = b->CreateStore(value, element_address);
549   AnnotateLoadStoreInstructionWithMetadata(store);
550 }
551 
CastToShape(const Shape & new_shape,llvm::IRBuilder<> * b) const552 IrArray IrArray::CastToShape(const Shape& new_shape,
553                              llvm::IRBuilder<>* b) const {
554   llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
555   llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module);
556   IrArray new_irarray(
557       b->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()), new_ir_type,
558       new_shape);
559   new_irarray.metadata_ = metadata_;
560   return new_irarray;
561 }
562 
ShapeIsCompatible(const Shape & a,const Shape & b)563 bool IrArray::Index::ShapeIsCompatible(const Shape& a, const Shape& b) {
564   // Compute strides for two sides of the comparison. Sometimes different shapes
565   // give the same strides:
566   //   [10, 20, 30, 1]{3,2,1,0} vs [10, 20, 1, 30]{3,2,1,0}
567   // which should be considered compatible.
568   const auto get_strides = [](const Shape& shape) {
569     int rank = shape.dimensions().size();
570     int64_t stride = 1;
571     std::vector<int64_t> strides;
572     for (int i = 0; i < rank; i++) {
573       auto dim = shape.dimensions(shape.layout().minor_to_major(i));
574       if (dim != 1) {
575         stride *= dim;
576         strides.push_back(stride);
577       }
578     }
579     return strides;
580   };
581 
582   return get_strides(a) == get_strides(b);
583 }
584 
585 }  // namespace llvm_ir
586 }  // namespace xla
587