• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
17 
18 #include <tuple>
19 #include <vector>
20 
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/Instructions.h"
23 #include "llvm/IR/Value.h"
24 #include "tensorflow/compiler/xla/layout_util.h"
25 #include "tensorflow/compiler/xla/permutation_util.h"
26 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/statusor.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/types.h"
33 
34 namespace xla {
35 namespace llvm_ir {
36 
Index(absl::Span<llvm::Value * const> multidim,llvm::Value * linear,const Shape & shape,llvm::Type * index_type)37 IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
38                       llvm::Value* linear, const Shape& shape,
39                       llvm::Type* index_type)
40     : Index(multidim, shape, index_type) {
41   CHECK_NE(linear, nullptr);
42   linear_ = linear;
43 }
44 
Delinearize(std::vector<llvm::Value * > * multidim,llvm::Value * linear,const Shape & shape,llvm::IRBuilder<> * b) const45 void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
46                                  llvm::Value* linear, const Shape& shape,
47                                  llvm::IRBuilder<>* b) const {
48   int64 divisor = 1;
49   const Layout& layout = shape.layout();
50   for (int64 i = 0; i < layout.minor_to_major_size(); ++i) {
51     int64 dimension = layout.minor_to_major(i);
52     int64 size_of_current_dimension = shape.dimensions(dimension);
53 
54     // If i is not the last dimension, compute
55     //   (linear_index / divisor) % current_dimension.
56     // If i is the last dimension, we can skip the mod, because we assume that
57     // linear is in bounds.
58     //
59     // TODO(jlebar): We could add bounds checks here and elsewhere in this file,
60     // guarded under some sort of xla-memcheck flag.  This might be particularly
61     // useful because cuda-memcheck can't help us much in XLA: Most of our
62     // memory lives in one big allocation, so cuda-memcheck can't detect
63     // out-of-bounds accesses.
64     auto* quot = b->CreateUDiv(linear, GetConstantWithIndexType(divisor));
65     if (i < layout.minor_to_major_size() - 1) {
66       (*multidim)[dimension] = b->CreateURem(
67           quot, GetConstantWithIndexType(size_of_current_dimension));
68     } else {
69       (*multidim)[dimension] = quot;
70     }
71     divisor *= size_of_current_dimension;
72   }
73 }
74 
Delinearize(std::vector<llvm::Value * > * multidim,llvm::Value * linear,const Shape & shape,absl::Span<llvm::Value * > dynamic_dims,llvm::IRBuilder<> * b) const75 void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
76                                  llvm::Value* linear, const Shape& shape,
77                                  absl::Span<llvm::Value*> dynamic_dims,
78                                  llvm::IRBuilder<>* b) const {
79   CHECK_EQ(shape.dimensions_size(), dynamic_dims.size());
80   CHECK_EQ(multidim_.size(), shape.rank());
81   llvm::Value* divisor = GetConstantWithIndexType(1);
82   const Layout& layout = shape.layout();
83   for (int64 i = 0; i < layout.minor_to_major_size(); ++i) {
84     int64 dimension = layout.minor_to_major(i);
85 
86     // If i is not the last dimension, compute
87     //   (linear_index / divisor) % current_dimension.
88     // If i is the last dimension, we can skip the mod, because we assume that
89     // linear is in bounds.
90     auto* quot = b->CreateUDiv(linear, divisor, "quot");
91     if (i < layout.minor_to_major_size() - 1) {
92       (*multidim)[dimension] =
93           b->CreateURem(quot, dynamic_dims[dimension], "dim_value");
94       divisor = b->CreateMul(divisor, dynamic_dims[dimension], "divisor");
95     } else {
96       (*multidim)[dimension] = quot;
97     }
98   }
99 }
100 
Index(llvm::Value * linear,const Shape & shape,llvm::IRBuilder<> * b)101 IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
102                       llvm::IRBuilder<>* b)
103     : multidim_(shape.rank()),
104       linear_(linear),
105       layout_(shape.layout()),
106       dims_(shape.dimensions().begin(), shape.dimensions().end()) {
107   CHECK_NE(linear, nullptr);
108   index_type_ = linear->getType();
109   CHECK(LayoutUtil::HasLayout(shape))
110       << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
111       << " should have a layout.";
112   Delinearize(&multidim_, linear, shape, b);
113 }
114 
Index(llvm::Value * linear,const Shape & shape,absl::Span<llvm::Value * > dynamic_dims,llvm::IRBuilder<> * b)115 IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
116                       absl::Span<llvm::Value*> dynamic_dims,
117                       llvm::IRBuilder<>* b)
118     : multidim_(shape.rank()),
119       linear_(linear),
120       layout_(shape.layout()),
121       dims_(shape.dimensions().begin(), shape.dimensions().end()) {
122   CHECK_NE(linear, nullptr);
123   index_type_ = linear->getType();
124   CHECK(LayoutUtil::HasLayout(shape))
125       << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
126       << " should have a layout.";
127   Delinearize(&multidim_, linear, shape, dynamic_dims, b);
128 }
129 
Index(absl::Span<llvm::Value * const> multidim,absl::Span<int64 const> dimensions,llvm::Type * index_type)130 IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
131                       absl::Span<int64 const> dimensions,
132                       llvm::Type* index_type)
133     : Index(multidim, ShapeUtil::MakeShape(/*arbitrary*/ PRED, dimensions),
134             index_type) {}
135 
Index(absl::Span<llvm::Value * const> multidim,const Shape & shape,llvm::Type * index_type)136 IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
137                       const Shape& shape, llvm::Type* index_type)
138     : multidim_(multidim.begin(), multidim.end()),
139       linear_(nullptr),
140       layout_(shape.layout()),
141       dims_(shape.dimensions().begin(), shape.dimensions().end()),
142       index_type_(index_type) {
143   CHECK_NE(index_type_, nullptr);
144   CHECK_EQ(shape.dimensions_size(), multidim.size());
145   for (const auto* dim : multidim) {
146     CHECK_NE(dim, nullptr);
147   }
148   CHECK(LayoutUtil::HasLayout(shape))
149       << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
150       << " should have a layout.";
151 }
152 
IrArray(llvm::Value * base_ptr,Shape shape)153 IrArray::IrArray(llvm::Value* base_ptr, Shape shape)
154     : base_ptr_(base_ptr), shape_(std::move(shape)) {
155   TF_CHECK_OK(ShapeUtil::ValidateShape(shape));
156   CHECK(base_ptr_->getType()->isPointerTy());
157   int depth = 0;
158   element_type_ =
159       llvm::cast<llvm::PointerType>(base_ptr_->getType())->getElementType();
160   while (llvm::ArrayType* array_type =
161              llvm::dyn_cast<llvm::ArrayType>(element_type_)) {
162     element_type_ = array_type->getElementType();
163     ++depth;
164   }
165 
166   if (!shape_.IsArray() || ShapeUtil::IsScalar(shape_)) {
167     DCHECK(depth == 1 || depth == 0) << depth;
168   } else {
169     DCHECK_EQ(depth, shape_.rank()) << shape.ShortDebugString();
170   }
171 }
172 
173 // Returns whether the given linear index is valid on the given shape.
LinearValidOnShape(const Shape & a) const174 bool IrArray::Index::LinearValidOnShape(const Shape& a) const {
175   auto b = ShapeUtil::MakeShape(a.element_type(), dims_);
176   *b.mutable_layout() = layout_;
177   return linear_ != nullptr &&
178          ShapeUtil::ElementsIn(a) == ShapeUtil::ElementsIn(b) &&
179          ShapeUtil::ReshapeIsBitcast(a, b);
180 }
181 
SourceIndexOfReshape(const Shape & output_shape,const Shape & input_shape,llvm::IRBuilder<> * builder) const182 IrArray::Index IrArray::Index::SourceIndexOfReshape(
183     const Shape& output_shape, const Shape& input_shape,
184     llvm::IRBuilder<>* builder) const {
185   CHECK_EQ(multidim_.size(), output_shape.rank());
186   std::vector<llvm::Value*> source_multidim_index(
187       input_shape.rank(), llvm::UndefValue::get(index_type_));
188   auto trivial_reshape =
189       ShapeUtil::InsertedOrDeleted1SizedDimensions(input_shape, output_shape);
190   if (std::get<0>(trivial_reshape)) {
191     // The 1-sized dimensions which only appear in 'input_shape'.
192     auto deleted_dims_indices = std::get<1>(trivial_reshape);
193     // The 1-sized dimensions which only appear in 'output_shape'.
194     auto inserted_dims_indices = std::get<2>(trivial_reshape);
195 
196     // This is a two-way merge of 'deleted_dims_indices' with indexing into
197     // 'source_multidim_index', and a two-way merge of 'inserted_dims_indices'
198     // with indexing into 'multidim_'. When we find a dimension in
199     // 'source_multidim_index' which does not belong to 'deleted_dims_indices',
200     // we retrieve the corresponding value from 'multidim_' (skipping any
201     // indices that appear in 'inserted_dims_indices').
202     for (int64 i = 0, j = 0, k = 0, l = 0; i < source_multidim_index.size();
203          ++i) {
204       if (j == deleted_dims_indices.size() || deleted_dims_indices[j] > i) {
205         // This is a dimension that was preserved. Take the matching value from
206         // multidim_.
207         while (l < inserted_dims_indices.size() &&
208                inserted_dims_indices[l] == k) {
209           // Skip 1-sized dimensions.
210           ++k;
211           ++l;
212         }
213         source_multidim_index[i] = multidim_[k];
214         ++k;
215       } else {
216         // This is a 1-sized dimension that only appears in the operand.
217         source_multidim_index[i] = GetConstantWithIndexType(0);
218         ++j;
219       }
220     }
221   } else {
222     const auto common_factors =
223         CommonFactors(AsInt64Slice(input_shape.dimensions()),
224                       AsInt64Slice(output_shape.dimensions()));
225     // We compute the source indices in each common factor from only the target
226     // indices in the same common factor.
227     for (ssize_t k = common_factors.size() - 2; k >= 0; --k) {
228       absl::Span<int64 const> dimensions =
229           AsInt64Slice(output_shape.dimensions())
230               .subspan(common_factors[k].second,
231                        common_factors[k + 1].second - common_factors[k].second);
232       llvm::Value* logical_linear_index =
233           Index(absl::Span<llvm::Value* const>(multidim_).subspan(
234                     common_factors[k].second,
235                     common_factors[k + 1].second - common_factors[k].second),
236                 dimensions, index_type_)
237               .Linearize(dimensions, builder);
238       // Delinearizes logical_linear_index for the source array in row-major
239       // collapsed order. The first rank-1 indices are the remainder of the
240       // linear index by each dimension size.
241       for (int64 i = common_factors[k + 1].first - 1;
242            i >= common_factors[k].first; --i) {
243         llvm::Value* divisor =
244             GetConstantWithIndexType(input_shape.dimensions(i));
245         if (input_shape.dimensions(i) == 1) {
246           source_multidim_index[i] = GetConstantWithIndexType(0);
247         } else if (i == common_factors[k].first) {
248           source_multidim_index[i] = logical_linear_index;
249         } else {
250           source_multidim_index[i] =
251               builder->CreateURem(logical_linear_index, divisor);
252         }
253         logical_linear_index =
254             builder->CreateUDiv(logical_linear_index, divisor);
255       }
256     }
257   }
258 
259   if (linear() != nullptr && LayoutUtil::HasLayout(input_shape) &&
260       LayoutUtil::HasLayout(output_shape) &&
261       ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) {
262     return Index(source_multidim_index, linear(), input_shape, index_type_);
263   }
264   return Index(source_multidim_index, input_shape, index_type_);
265 }
266 
SourceIndexOfSlice(const Shape & operand_shape,absl::Span<const int64> starts,absl::Span<const int64> strides,llvm::IRBuilder<> * builder) const267 IrArray::Index IrArray::Index::SourceIndexOfSlice(
268     const Shape& operand_shape, absl::Span<const int64> starts,
269     absl::Span<const int64> strides, llvm::IRBuilder<>* builder) const {
270   std::vector<llvm::Value*> source_multi_index(multidim_.size());
271   for (int i = 0; i < multidim_.size(); ++i) {
272     int64 stride = strides[i];
273     if (stride != 1) {
274       source_multi_index[i] = builder->CreateAdd(
275           builder->CreateMul(multidim_[i], GetConstantWithIndexType(stride)),
276           GetConstantWithIndexType(starts[i]));
277     } else {
278       source_multi_index[i] =
279           builder->CreateAdd(multidim_[i], GetConstantWithIndexType(starts[i]));
280     }
281   }
282   return Index(source_multi_index, operand_shape, index_type_);
283 }
284 
SourceIndexOfTranspose(const Shape & shape,const Shape & operand_shape,absl::Span<const int64> dimension_mapping) const285 IrArray::Index IrArray::Index::SourceIndexOfTranspose(
286     const Shape& shape, const Shape& operand_shape,
287     absl::Span<const int64> dimension_mapping) const {
288   std::vector<llvm::Value*> operand_multidim_index =
289       PermuteInverse(multidim(), dimension_mapping);
290 
291   if (linear() != nullptr && LayoutUtil::HasLayout(operand_shape) &&
292       LayoutUtil::HasLayout(shape) &&
293       ShapeUtil::TransposeIsBitcast(operand_shape, shape, dimension_mapping)) {
294     return Index(operand_multidim_index, linear(), operand_shape, index_type_);
295   }
296 
297   return Index(operand_multidim_index, operand_shape, index_type_);
298 }
299 
SourceIndexOfBitcast(const Shape & shape,const Shape & operand_shape,llvm::IRBuilder<> * builder) const300 IrArray::Index IrArray::Index::SourceIndexOfBitcast(
301     const Shape& shape, const Shape& operand_shape,
302     llvm::IRBuilder<>* builder) const {
303   CHECK(LayoutUtil::HasLayout(shape) && LayoutUtil::HasLayout(operand_shape));
304 
305   // In case the bitcast is just a reshape, we can use SourceIndexOfReshape()
306   // instead. This will reuse linear() if possible, so we don't have to build a
307   // new 'linear_index'.
308   if (ShapeUtil::ReshapeIsBitcast(operand_shape, shape)) {
309     return SourceIndexOfReshape(shape, operand_shape, builder);
310   }
311 
312   // If we have a linear index, we can definitely use it because we know the
313   // operation is a bitcast. This will recompute the multi-dimensional index for
314   // the operand based on the linear index.
315   if (linear() != nullptr) {
316     return Index(linear(), operand_shape, builder);
317   }
318 
319   // First linearize the index coming from the output of the bitcast. We want
320   // the physical index of the element in the buffer. This is like Linearize,
321   // but takes the layout into account.
322   int64 scale = 1;
323   llvm::Value* linear_index = GetConstantWithIndexType(0);
324   for (auto dimension : LayoutUtil::MinorToMajor(shape)) {
325     linear_index = builder->CreateAdd(
326         linear_index,
327         builder->CreateMul(multidim_[dimension],
328                            GetConstantWithIndexType(scale), "",
329                            /*HasNUW=*/true, /*HasNSW=*/true),
330         "", /*HasNUW=*/true, /*HasNSW=*/true);
331     scale *= shape.dimensions(dimension);
332   }
333 
334   return Index(linear_index, operand_shape, builder);
335 }
336 
SourceIndexOfBroadcast(const Shape & shape,const Shape & operand_shape,absl::Span<const int64> dimension_mapping,llvm::IRBuilder<> * builder) const337 IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
338     const Shape& shape, const Shape& operand_shape,
339     absl::Span<const int64> dimension_mapping,
340     llvm::IRBuilder<>* builder) const {
341   int64 rank = operand_shape.rank();
342   std::vector<llvm::Value*> source_index(rank);
343   for (int64 i = 0; i < rank; ++i) {
344     source_index[i] = multidim_[dimension_mapping[i]];
345   }
346   if (linear_ == nullptr || !LayoutUtil::HasLayout(operand_shape) ||
347       !LayoutUtil::HasLayout(shape)) {
348     return Index(source_index, operand_shape, index_type_);
349   }
350   // High-level idea: we can reuse the linear index if the broadcasted
351   // dimensions are contiguous, and this part of the operation is a bitcast.
352   // The other dimensions can be masked out with a div and a mod operation.
353   std::vector<int64> logical_to_physical =
354       LayoutUtil::MakeLogicalToPhysical(shape.layout());
355   int64 output_rank = shape.rank();
356   // The minimum physical dimension that is broadcasted.
357   int64 min_broadcasted_dimension = output_rank;
358   // The maximum physical dimension that is broadcasted.
359   int64 max_broadcasted_dimension = -1;
360   for (int64 i = 0; i < rank; ++i) {
361     int64 physical_dim = logical_to_physical[dimension_mapping[i]];
362     min_broadcasted_dimension =
363         std::min(min_broadcasted_dimension, physical_dim);
364     max_broadcasted_dimension =
365         std::max(max_broadcasted_dimension, physical_dim);
366   }
367   bool contiguous_broadcast_dimensions =
368       max_broadcasted_dimension - min_broadcasted_dimension == rank - 1;
369   if (!contiguous_broadcast_dimensions) {
370     return Index(source_index, operand_shape, index_type_);
371   }
372   // Check if the mapped dimensions are a bitcast.
373   std::vector<int64> operand_logical_to_physical =
374       LayoutUtil::MakeLogicalToPhysical(operand_shape.layout());
375   for (int64 i = 0; i < rank; ++i) {
376     if (operand_logical_to_physical[i] !=
377         logical_to_physical[dimension_mapping[i]] - min_broadcasted_dimension) {
378       return Index(source_index, operand_shape, index_type_);
379     }
380   }
381   llvm::Value* linear = linear_;
382   int64 divisor = 1;
383   for (int64 i = max_broadcasted_dimension + 1; i < output_rank; ++i) {
384     divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
385   }
386   if (divisor > 1) {
387     linear = builder->CreateUDiv(linear, GetConstantWithIndexType(divisor));
388   }
389   if (min_broadcasted_dimension > 0) {
390     int64 mod = 1;
391     for (int64 i = min_broadcasted_dimension; i <= max_broadcasted_dimension;
392          ++i) {
393       mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
394     }
395     linear = builder->CreateURem(linear, GetConstantWithIndexType(mod));
396   }
397   return Index(source_index, linear, operand_shape, index_type_);
398 }
399 
Linearize(absl::Span<const int64> dimensions,llvm::IRBuilder<> * builder) const400 llvm::Value* IrArray::Index::Linearize(absl::Span<const int64> dimensions,
401                                        llvm::IRBuilder<>* builder) const {
402   // Each dimension is multiplied by the product of the sizes of all
403   // earlier dimensions and added to the accumulator logical_linear_index.
404   CHECK_EQ(size(), dimensions.size());
405   llvm::Value* logical_linear_index = GetConstantWithIndexType(0);
406   int64 multiplier = 1;
407   for (ssize_t i = size() - 1; i >= 0; --i) {
408     llvm::Value* addend =
409         builder->CreateMul((*this)[i], GetConstantWithIndexType(multiplier), "",
410                            /*HasNUW=*/true, /*HasNSW=*/true);
411     addend = builder->CreateZExtOrTrunc(addend, index_type_);
412     logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "",
413                                               /*HasNUW=*/true, /*HasNSW=*/true);
414     multiplier *= dimensions[i];
415   }
416   return logical_linear_index;
417 }
418 
Linearize(const std::vector<llvm::Value * > & dynamic_dims,llvm::IRBuilder<> * builder) const419 llvm::Value* IrArray::Index::Linearize(
420     const std::vector<llvm::Value*>& dynamic_dims,
421     llvm::IRBuilder<>* builder) const {
422   // Each dimension is multiplied by the product of the sizes of all
423   // earlier dimensions and added to the accumulator logical_linear_index.
424   CHECK_EQ(size(), dynamic_dims.size());
425   llvm::Value* logical_linear_index = GetConstantWithIndexType(0);
426   llvm::Value* multiplier = GetConstantWithIndexType(1);
427   for (ssize_t i = size() - 1; i >= 0; --i) {
428     llvm::Value* addend = builder->CreateMul((*this)[i], multiplier, "",
429                                              /*HasNUW=*/true, /*HasNSW=*/true);
430     addend = builder->CreateZExtOrTrunc(addend, index_type_);
431     logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "",
432                                               /*HasNUW=*/true, /*HasNSW=*/true);
433     if (i) {
434       multiplier = builder->CreateMul(multiplier, dynamic_dims[i],
435                                       /*Name=*/"multiplier");
436     }
437   }
438   return logical_linear_index;
439 }
440 
EmitArrayElementAddress(const IrArray::Index & index,llvm::IRBuilder<> * b,absl::string_view name,bool use_linear_index) const441 llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
442                                               llvm::IRBuilder<>* b,
443                                               absl::string_view name,
444                                               bool use_linear_index) const {
445   if (ShapeUtil::IsScalar(shape_)) {
446     // Special handling of scalars: a scalar pretends to have the same value for
447     // every index, thus effectively implementing broadcasting of its value
448     // over higher-rank arrays.
449     return base_ptr_;
450   }
451   CHECK_EQ(index.size(), shape_.rank());
452   CHECK(index.ShapeIsCompatible(shape_));
453 
454   if (use_linear_index && index.LinearValidOnShape(shape_)) {
455     llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
456     return b->CreateInBoundsGEP(
457         b->CreateBitCast(base_ptr_,
458                          PrimitiveTypeToIrType(shape_.element_type(), module)
459                              ->getPointerTo()),
460         {index.linear()}, llvm_ir::AsStringRef(name));
461   }
462 
463   std::vector<llvm::Value*> actual_index;
464   for (int64 i = 0; i < index.size(); ++i) {
465     // When dimension i is of size 1, LLVM optimization is able to replace
466     // index[i] with 0. However, setting index[i] to 0 here still allows LLVM to
467     // produce better code in some cases.
468     auto dim = shape_.dimensions(i);
469     actual_index.push_back(
470         dim == 1 ? llvm::ConstantInt::get(index[i]->getType(), 0) : index[i]);
471   }
472 
473   // "base_ptr_" has the type of "<ir_type_for_its_shape>*"
474   // (e.g. [3 x [2 x float]]*). Therefore, the address of the indexed element
475   // should be computed by
476   //
477   //   getelementptr base_ptr_, 0, most major index, ..., most minor index
478   CHECK_GT(index.size(), 0);
479   std::vector<llvm::Value*> gep_indices(
480       1, llvm::ConstantInt::get(index[0]->getType(), 0));
481   for (int64 i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) {
482     int64 dimension = LayoutUtil::Major(shape_.layout(), i);
483     gep_indices.push_back(actual_index[dimension]);
484   }
485   return b->CreateInBoundsGEP(base_ptr_, gep_indices,
486                               llvm_ir::AsStringRef(name));
487 }
488 
AnnotateLoadStoreInstructionWithMetadata(llvm::Instruction * instruction) const489 void IrArray::AnnotateLoadStoreInstructionWithMetadata(
490     llvm::Instruction* instruction) const {
491   CHECK(llvm::isa<llvm::LoadInst>(instruction) ||
492         llvm::isa<llvm::StoreInst>(instruction));
493   CHECK(!llvm::isa<llvm::StoreInst>(instruction) || !is_invariant_)
494       << "Trying to create a store to an invariant IRArray.";
495 
496   for (const auto& kind_md_pair : metadata_) {
497     instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
498   }
499 }
500 
EmitReadArrayElement(const Index & index,llvm::IRBuilder<> * b,absl::string_view name,bool use_linear_index) const501 llvm::Value* IrArray::EmitReadArrayElement(const Index& index,
502                                            llvm::IRBuilder<>* b,
503                                            absl::string_view name,
504                                            bool use_linear_index) const {
505   llvm::Value* element_address =
506       EmitArrayElementAddress(index, b, name, use_linear_index);
507   llvm::LoadInst* load = b->CreateLoad(element_address, name.data());
508   AnnotateLoadStoreInstructionWithMetadata(load);
509   return load;
510 }
511 
EmitWriteArrayElement(const Index & index,llvm::Value * value,llvm::IRBuilder<> * b,bool use_linear_index) const512 void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value,
513                                     llvm::IRBuilder<>* b,
514                                     bool use_linear_index) const {
515   llvm::Value* element_address =
516       EmitArrayElementAddress(index, b, "", use_linear_index);
517   llvm::StoreInst* store = b->CreateStore(value, element_address);
518   AnnotateLoadStoreInstructionWithMetadata(store);
519 }
520 
CastToShape(const Shape & new_shape,llvm::IRBuilder<> * b) const521 IrArray IrArray::CastToShape(const Shape& new_shape,
522                              llvm::IRBuilder<>* b) const {
523   llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
524   llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module);
525   IrArray new_irarray(
526       b->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()), new_shape);
527   new_irarray.metadata_ = metadata_;
528   return new_irarray;
529 }
530 
ShapeIsCompatible(const Shape & a,const Shape & b)531 bool IrArray::Index::ShapeIsCompatible(const Shape& a, const Shape& b) {
532   // Compute strides for two sides of the comparison. Sometimes different shapes
533   // give the same strides:
534   //   [10, 20, 30, 1]{3,2,1,0} vs [10, 20, 1, 30]{3,2,1,0}
535   // which should be considered compatible.
536   const auto get_strides = [](const Shape& shape) {
537     int rank = shape.dimensions().size();
538     int64 stride = 1;
539     std::vector<int64> strides;
540     for (int i = 0; i < rank; i++) {
541       auto dim = shape.dimensions(shape.layout().minor_to_major(i));
542       if (dim != 1) {
543         stride *= dim;
544         strides.push_back(stride);
545       }
546     }
547     return strides;
548   };
549 
550   return get_strides(a) == get_strides(b);
551 }
552 
553 }  // namespace llvm_ir
554 }  // namespace xla
555