• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
17 
18 #include <tuple>
19 #include <vector>
20 
21 #include "llvm/IR/Constants.h"
22 #include "llvm/IR/Instructions.h"
23 #include "llvm/IR/Value.h"
24 #include "tensorflow/compiler/xla/layout_util.h"
25 #include "tensorflow/compiler/xla/permutation_util.h"
26 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
27 #include "tensorflow/compiler/xla/shape_util.h"
28 #include "tensorflow/compiler/xla/statusor.h"
29 #include "tensorflow/compiler/xla/util.h"
30 #include "tensorflow/compiler/xla/xla_data.pb.h"
31 #include "tensorflow/core/platform/logging.h"
32 #include "tensorflow/core/platform/types.h"
33 
34 namespace xla {
35 namespace llvm_ir {
36 
Index(absl::Span<llvm::Value * const> multidim,llvm::Value * linear,const Shape & shape,llvm::Type * index_type)37 IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
38                       llvm::Value* linear, const Shape& shape,
39                       llvm::Type* index_type)
40     : Index(multidim, shape, index_type) {
41   CHECK_NE(linear, nullptr);
42   linear_ = linear;
43 }
44 
Delinearize(std::vector<llvm::Value * > * multidim,llvm::Value * linear,const Shape & shape,llvm::IRBuilder<> * b) const45 void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
46                                  llvm::Value* linear, const Shape& shape,
47                                  llvm::IRBuilder<>* b) const {
48   int64_t divisor = 1;
49   const Layout& layout = shape.layout();
50   for (int64_t i = 0; i < layout.minor_to_major_size(); ++i) {
51     int64_t dimension = layout.minor_to_major(i);
52     int64_t size_of_current_dimension = shape.dimensions(dimension);
53 
54     // If i is not the last dimension, compute
55     //   (linear_index / divisor) % current_dimension.
56     // If i is the last dimension, we can skip the mod, because we assume that
57     // linear is in bounds.
58     //
59     // TODO(jlebar): We could add bounds checks here and elsewhere in this file,
60     // guarded under some sort of xla-memcheck flag.  This might be particularly
61     // useful because cuda-memcheck can't help us much in XLA: Most of our
62     // memory lives in one big allocation, so cuda-memcheck can't detect
63     // out-of-bounds accesses.
64     auto* quot = b->CreateUDiv(linear, GetConstantWithIndexType(divisor));
65     if (i < layout.minor_to_major_size() - 1) {
66       (*multidim)[dimension] = b->CreateURem(
67           quot, GetConstantWithIndexType(size_of_current_dimension));
68     } else {
69       (*multidim)[dimension] = quot;
70     }
71     divisor *= size_of_current_dimension;
72   }
73 }
74 
Delinearize(std::vector<llvm::Value * > * multidim,llvm::Value * linear,const Shape & shape,absl::Span<llvm::Value * > dynamic_dims,llvm::IRBuilder<> * b) const75 void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
76                                  llvm::Value* linear, const Shape& shape,
77                                  absl::Span<llvm::Value*> dynamic_dims,
78                                  llvm::IRBuilder<>* b) const {
79   CHECK_EQ(shape.dimensions_size(), dynamic_dims.size());
80   CHECK_EQ(multidim_.size(), shape.rank());
81   llvm::Value* divisor = GetConstantWithIndexType(1);
82   const Layout& layout = shape.layout();
83   for (int64_t i = 0; i < layout.minor_to_major_size(); ++i) {
84     int64_t dimension = layout.minor_to_major(i);
85 
86     // If i is not the last dimension, compute
87     //   (linear_index / divisor) % current_dimension.
88     // If i is the last dimension, we can skip the mod, because we assume that
89     // linear is in bounds.
90     auto* quot = b->CreateUDiv(linear, divisor, "quot");
91     if (i < layout.minor_to_major_size() - 1) {
92       (*multidim)[dimension] =
93           b->CreateURem(quot, dynamic_dims[dimension], "dim_value");
94       divisor = b->CreateMul(divisor, dynamic_dims[dimension], "divisor");
95     } else {
96       (*multidim)[dimension] = quot;
97     }
98   }
99 }
100 
Index(llvm::Value * linear,const Shape & shape,llvm::IRBuilder<> * b)101 IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
102                       llvm::IRBuilder<>* b)
103     : multidim_(shape.rank()),
104       linear_(linear),
105       layout_(shape.layout()),
106       dims_(shape.dimensions().begin(), shape.dimensions().end()) {
107   CHECK_NE(linear, nullptr);
108   index_type_ = linear->getType();
109   CHECK(LayoutUtil::HasLayout(shape))
110       << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
111       << " should have a layout.";
112   Delinearize(&multidim_, linear, shape, b);
113 }
114 
Index(llvm::Value * linear,absl::Span<llvm::Value * const> multidim,const Shape & shape,llvm::IRBuilder<> * b)115 IrArray::Index::Index(llvm::Value* linear,
116                       absl::Span<llvm::Value* const> multidim,
117                       const Shape& shape, llvm::IRBuilder<>* b)
118     : multidim_(shape.rank()),
119       linear_(linear),
120       layout_(shape.layout()),
121       dims_(shape.dimensions().begin(), shape.dimensions().end()) {
122   CHECK_NE(linear, nullptr);
123   index_type_ = linear->getType();
124   CHECK_EQ(multidim.size(), shape.rank());
125   for (auto dim : multidim) {
126     if (dim) {
127       CHECK_EQ(dim->getType(), index_type_);
128     }
129   }
130   CHECK(LayoutUtil::HasLayout(shape))
131       << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
132       << " should have a layout.";
133   Delinearize(&multidim_, linear, shape, b);
134   for (int i = 0; i < multidim.size(); ++i) {
135     if (multidim[i] != nullptr) {
136       multidim_[i] = multidim[i];
137     }
138   }
139 }
140 
Index(llvm::Value * linear,const Shape & shape,absl::Span<llvm::Value * > dynamic_dims,llvm::IRBuilder<> * b)141 IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
142                       absl::Span<llvm::Value*> dynamic_dims,
143                       llvm::IRBuilder<>* b)
144     : multidim_(shape.rank()),
145       linear_(linear),
146       layout_(shape.layout()),
147       dims_(shape.dimensions().begin(), shape.dimensions().end()) {
148   CHECK_NE(linear, nullptr);
149   index_type_ = linear->getType();
150   CHECK(LayoutUtil::HasLayout(shape))
151       << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
152       << " should have a layout.";
153   Delinearize(&multidim_, linear, shape, dynamic_dims, b);
154 }
155 
Index(absl::Span<llvm::Value * const> multidim,absl::Span<int64 const> dimensions,llvm::Type * index_type)156 IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
157                       absl::Span<int64 const> dimensions,
158                       llvm::Type* index_type)
159     : Index(multidim, ShapeUtil::MakeShape(/*arbitrary*/ PRED, dimensions),
160             index_type) {}
161 
Index(absl::Span<llvm::Value * const> multidim,const Shape & shape,llvm::Type * index_type)162 IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
163                       const Shape& shape, llvm::Type* index_type)
164     : multidim_(multidim.begin(), multidim.end()),
165       linear_(nullptr),
166       layout_(shape.layout()),
167       dims_(shape.dimensions().begin(), shape.dimensions().end()),
168       index_type_(index_type) {
169   CHECK_NE(index_type_, nullptr);
170   CHECK_EQ(shape.dimensions_size(), multidim.size());
171   for (const auto* dim : multidim) {
172     CHECK_NE(dim, nullptr);
173   }
174   CHECK(LayoutUtil::HasLayout(shape))
175       << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
176       << " should have a layout.";
177 }
178 
IrArray(llvm::Value * base_ptr,Shape shape)179 IrArray::IrArray(llvm::Value* base_ptr, Shape shape)
180     : base_ptr_(base_ptr), shape_(std::move(shape)) {
181   TF_CHECK_OK(ShapeUtil::ValidateShape(shape));
182   CHECK(base_ptr_->getType()->isPointerTy());
183   int depth = 0;
184   element_type_ =
185       llvm::cast<llvm::PointerType>(base_ptr_->getType())->getElementType();
186   while (llvm::ArrayType* array_type =
187              llvm::dyn_cast<llvm::ArrayType>(element_type_)) {
188     element_type_ = array_type->getElementType();
189     ++depth;
190   }
191 
192   if (!shape_.IsArray() || ShapeUtil::IsScalar(shape_)) {
193     DCHECK(depth == 1 || depth == 0) << depth;
194   } else {
195     DCHECK_EQ(depth, shape_.rank()) << shape.ShortDebugString();
196   }
197 }
198 
199 // Returns whether the given linear index is valid on the given shape.
LinearValidOnShape(const Shape & a) const200 bool IrArray::Index::LinearValidOnShape(const Shape& a) const {
201   auto b = ShapeUtil::MakeShape(a.element_type(), dims_);
202   *b.mutable_layout() = layout_;
203   return linear_ != nullptr &&
204          ShapeUtil::ElementsIn(a) == ShapeUtil::ElementsIn(b) &&
205          ShapeUtil::ReshapeIsBitcast(a, b);
206 }
207 
SourceIndexOfReshape(const Shape & output_shape,const Shape & input_shape,llvm::IRBuilder<> * builder) const208 IrArray::Index IrArray::Index::SourceIndexOfReshape(
209     const Shape& output_shape, const Shape& input_shape,
210     llvm::IRBuilder<>* builder) const {
211   CHECK_EQ(multidim_.size(), output_shape.rank());
212   std::vector<llvm::Value*> source_multidim_index(
213       input_shape.rank(), llvm::UndefValue::get(index_type_));
214   auto trivial_reshape =
215       ShapeUtil::InsertedOrDeleted1SizedDimensions(input_shape, output_shape);
216   if (std::get<0>(trivial_reshape)) {
217     // The 1-sized dimensions which only appear in 'input_shape'.
218     auto deleted_dims_indices = std::get<1>(trivial_reshape);
219     // The 1-sized dimensions which only appear in 'output_shape'.
220     auto inserted_dims_indices = std::get<2>(trivial_reshape);
221 
222     // This is a two-way merge of 'deleted_dims_indices' with indexing into
223     // 'source_multidim_index', and a two-way merge of 'inserted_dims_indices'
224     // with indexing into 'multidim_'. When we find a dimension in
225     // 'source_multidim_index' which does not belong to 'deleted_dims_indices',
226     // we retrieve the corresponding value from 'multidim_' (skipping any
227     // indices that appear in 'inserted_dims_indices').
228     for (int64_t i = 0, j = 0, k = 0, l = 0; i < source_multidim_index.size();
229          ++i) {
230       if (j == deleted_dims_indices.size() || deleted_dims_indices[j] > i) {
231         // This is a dimension that was preserved. Take the matching value from
232         // multidim_.
233         while (l < inserted_dims_indices.size() &&
234                inserted_dims_indices[l] == k) {
235           // Skip 1-sized dimensions.
236           ++k;
237           ++l;
238         }
239         source_multidim_index[i] = multidim_[k];
240         ++k;
241       } else {
242         // This is a 1-sized dimension that only appears in the operand.
243         source_multidim_index[i] = GetConstantWithIndexType(0);
244         ++j;
245       }
246     }
247   } else {
248     const auto common_factors =
249         CommonFactors(AsInt64Slice(input_shape.dimensions()),
250                       AsInt64Slice(output_shape.dimensions()));
251     // We compute the source indices in each common factor from only the target
252     // indices in the same common factor.
253     for (ssize_t k = common_factors.size() - 2; k >= 0; --k) {
254       absl::Span<int64 const> dimensions =
255           AsInt64Slice(output_shape.dimensions())
256               .subspan(common_factors[k].second,
257                        common_factors[k + 1].second - common_factors[k].second);
258       llvm::Value* logical_linear_index =
259           Index(absl::Span<llvm::Value* const>(multidim_).subspan(
260                     common_factors[k].second,
261                     common_factors[k + 1].second - common_factors[k].second),
262                 dimensions, index_type_)
263               .Linearize(dimensions, builder);
264       // Delinearizes logical_linear_index for the source array in row-major
265       // collapsed order. The first rank-1 indices are the remainder of the
266       // linear index by each dimension size.
267       for (int64_t i = common_factors[k + 1].first - 1;
268            i >= common_factors[k].first; --i) {
269         llvm::Value* divisor =
270             GetConstantWithIndexType(input_shape.dimensions(i));
271         if (input_shape.dimensions(i) == 1) {
272           source_multidim_index[i] = GetConstantWithIndexType(0);
273         } else if (i == common_factors[k].first) {
274           source_multidim_index[i] = logical_linear_index;
275         } else {
276           source_multidim_index[i] =
277               builder->CreateURem(logical_linear_index, divisor);
278         }
279         logical_linear_index =
280             builder->CreateUDiv(logical_linear_index, divisor);
281       }
282     }
283   }
284 
285   if (linear() != nullptr && LayoutUtil::HasLayout(input_shape) &&
286       LayoutUtil::HasLayout(output_shape) &&
287       ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) {
288     return Index(source_multidim_index, linear(), input_shape, index_type_);
289   }
290   return Index(source_multidim_index, input_shape, index_type_);
291 }
292 
SourceIndexOfSlice(const Shape & operand_shape,absl::Span<const int64> starts,absl::Span<const int64> strides,llvm::IRBuilder<> * builder) const293 IrArray::Index IrArray::Index::SourceIndexOfSlice(
294     const Shape& operand_shape, absl::Span<const int64> starts,
295     absl::Span<const int64> strides, llvm::IRBuilder<>* builder) const {
296   std::vector<llvm::Value*> source_multi_index(multidim_.size());
297   for (int i = 0; i < multidim_.size(); ++i) {
298     int64_t stride = strides[i];
299     if (stride != 1) {
300       source_multi_index[i] = builder->CreateAdd(
301           builder->CreateMul(multidim_[i], GetConstantWithIndexType(stride)),
302           GetConstantWithIndexType(starts[i]));
303     } else {
304       source_multi_index[i] =
305           builder->CreateAdd(multidim_[i], GetConstantWithIndexType(starts[i]));
306     }
307   }
308   return Index(source_multi_index, operand_shape, index_type_);
309 }
310 
SourceIndexOfTranspose(const Shape & shape,const Shape & operand_shape,absl::Span<const int64> dimension_mapping) const311 IrArray::Index IrArray::Index::SourceIndexOfTranspose(
312     const Shape& shape, const Shape& operand_shape,
313     absl::Span<const int64> dimension_mapping) const {
314   std::vector<llvm::Value*> operand_multidim_index =
315       PermuteInverse(multidim(), dimension_mapping);
316 
317   if (linear() != nullptr && LayoutUtil::HasLayout(operand_shape) &&
318       LayoutUtil::HasLayout(shape) &&
319       ShapeUtil::TransposeIsBitcast(operand_shape, shape, dimension_mapping)) {
320     return Index(operand_multidim_index, linear(), operand_shape, index_type_);
321   }
322 
323   return Index(operand_multidim_index, operand_shape, index_type_);
324 }
325 
SourceIndexOfBitcast(const Shape & shape,const Shape & operand_shape,llvm::IRBuilder<> * builder) const326 IrArray::Index IrArray::Index::SourceIndexOfBitcast(
327     const Shape& shape, const Shape& operand_shape,
328     llvm::IRBuilder<>* builder) const {
329   CHECK(LayoutUtil::HasLayout(shape) && LayoutUtil::HasLayout(operand_shape));
330 
331   // In case the bitcast is just a reshape, we can use SourceIndexOfReshape()
332   // instead. This will reuse linear() if possible, so we don't have to build a
333   // new 'linear_index'.
334   if (ShapeUtil::ReshapeIsBitcast(operand_shape, shape)) {
335     return SourceIndexOfReshape(shape, operand_shape, builder);
336   }
337 
338   // If we have a linear index, we can definitely use it because we know the
339   // operation is a bitcast. This will recompute the multi-dimensional index for
340   // the operand based on the linear index.
341   if (linear() != nullptr) {
342     return Index(linear(), operand_shape, builder);
343   }
344 
345   // First linearize the index coming from the output of the bitcast. We want
346   // the physical index of the element in the buffer. This is like Linearize,
347   // but takes the layout into account.
348   int64_t scale = 1;
349   llvm::Value* linear_index = GetConstantWithIndexType(0);
350   for (auto dimension : LayoutUtil::MinorToMajor(shape)) {
351     linear_index = builder->CreateAdd(
352         linear_index,
353         builder->CreateMul(multidim_[dimension],
354                            GetConstantWithIndexType(scale), "",
355                            /*HasNUW=*/true, /*HasNSW=*/true),
356         "", /*HasNUW=*/true, /*HasNSW=*/true);
357     scale *= shape.dimensions(dimension);
358   }
359 
360   return Index(linear_index, operand_shape, builder);
361 }
362 
SourceIndexOfBroadcast(const Shape & shape,const Shape & operand_shape,absl::Span<const int64> dimension_mapping,llvm::IRBuilder<> * builder) const363 IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
364     const Shape& shape, const Shape& operand_shape,
365     absl::Span<const int64> dimension_mapping,
366     llvm::IRBuilder<>* builder) const {
367   int64_t rank = operand_shape.rank();
368   std::vector<llvm::Value*> source_index(rank);
369   for (int64_t i = 0; i < rank; ++i) {
370     source_index[i] = multidim_[dimension_mapping[i]];
371   }
372   if (linear_ == nullptr || !LayoutUtil::HasLayout(operand_shape) ||
373       !LayoutUtil::HasLayout(shape) || rank == 1) {
374     return Index(source_index, operand_shape, index_type_);
375   }
376   // High-level idea: we can reuse the linear index if the broadcasted
377   // dimensions are contiguous, and this part of the operation is a bitcast.
378   // The other dimensions can be masked out with a div and a mod operation.
379   std::vector<int64> logical_to_physical =
380       LayoutUtil::MakeLogicalToPhysical(shape.layout());
381   int64_t output_rank = shape.rank();
382   // The minimum physical dimension that is broadcasted.
383   int64_t min_broadcasted_dimension = output_rank;
384   // The maximum physical dimension that is broadcasted.
385   int64_t max_broadcasted_dimension = -1;
386   for (int64_t i = 0; i < rank; ++i) {
387     int64_t physical_dim = logical_to_physical[dimension_mapping[i]];
388     min_broadcasted_dimension =
389         std::min(min_broadcasted_dimension, physical_dim);
390     max_broadcasted_dimension =
391         std::max(max_broadcasted_dimension, physical_dim);
392   }
393   bool contiguous_broadcast_dimensions =
394       max_broadcasted_dimension - min_broadcasted_dimension == rank - 1;
395   if (!contiguous_broadcast_dimensions) {
396     return Index(source_index, operand_shape, index_type_);
397   }
398   // Check if the mapped dimensions are a bitcast.
399   std::vector<int64> operand_logical_to_physical =
400       LayoutUtil::MakeLogicalToPhysical(operand_shape.layout());
401   for (int64_t i = 0; i < rank; ++i) {
402     if (operand_logical_to_physical[i] !=
403         logical_to_physical[dimension_mapping[i]] - min_broadcasted_dimension) {
404       return Index(source_index, operand_shape, index_type_);
405     }
406   }
407   llvm::Value* linear = linear_;
408   int64_t divisor = 1;
409   for (int64_t i = max_broadcasted_dimension + 1; i < output_rank; ++i) {
410     divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
411   }
412   if (divisor > 1) {
413     linear = builder->CreateUDiv(linear, GetConstantWithIndexType(divisor));
414   }
415   if (min_broadcasted_dimension > 0) {
416     int64_t mod = 1;
417     for (int64_t i = min_broadcasted_dimension; i <= max_broadcasted_dimension;
418          ++i) {
419       mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
420     }
421     linear = builder->CreateURem(linear, GetConstantWithIndexType(mod));
422   }
423   return Index(source_index, linear, operand_shape, index_type_);
424 }
425 
Linearize(absl::Span<const int64> dimensions,llvm::IRBuilder<> * builder) const426 llvm::Value* IrArray::Index::Linearize(absl::Span<const int64> dimensions,
427                                        llvm::IRBuilder<>* builder) const {
428   // Each dimension is multiplied by the product of the sizes of all
429   // earlier dimensions and added to the accumulator logical_linear_index.
430   CHECK_EQ(size(), dimensions.size());
431   llvm::Value* logical_linear_index = GetConstantWithIndexType(0);
432   int64_t multiplier = 1;
433   for (ssize_t i = size() - 1; i >= 0; --i) {
434     llvm::Value* addend =
435         builder->CreateMul((*this)[i], GetConstantWithIndexType(multiplier), "",
436                            /*HasNUW=*/true, /*HasNSW=*/true);
437     addend = builder->CreateZExtOrTrunc(addend, index_type_);
438     logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "",
439                                               /*HasNUW=*/true, /*HasNSW=*/true);
440     multiplier *= dimensions[i];
441   }
442   return logical_linear_index;
443 }
444 
Linearize(const std::vector<llvm::Value * > & dynamic_dims,llvm::IRBuilder<> * builder) const445 llvm::Value* IrArray::Index::Linearize(
446     const std::vector<llvm::Value*>& dynamic_dims,
447     llvm::IRBuilder<>* builder) const {
448   // Each dimension is multiplied by the product of the sizes of all
449   // earlier dimensions and added to the accumulator logical_linear_index.
450   CHECK_EQ(size(), dynamic_dims.size());
451   llvm::Value* logical_linear_index = GetConstantWithIndexType(0);
452   llvm::Value* multiplier = GetConstantWithIndexType(1);
453   for (ssize_t i = size() - 1; i >= 0; --i) {
454     llvm::Value* addend = builder->CreateMul((*this)[i], multiplier, "",
455                                              /*HasNUW=*/true, /*HasNSW=*/true);
456     addend = builder->CreateZExtOrTrunc(addend, index_type_);
457     logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "",
458                                               /*HasNUW=*/true, /*HasNSW=*/true);
459     if (i) {
460       multiplier = builder->CreateMul(multiplier, dynamic_dims[i],
461                                       /*Name=*/"multiplier");
462     }
463   }
464   return logical_linear_index;
465 }
466 
EmitArrayElementAddress(const IrArray::Index & index,llvm::IRBuilder<> * b,absl::string_view name,bool use_linear_index) const467 llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
468                                               llvm::IRBuilder<>* b,
469                                               absl::string_view name,
470                                               bool use_linear_index) const {
471   if (ShapeUtil::IsScalar(shape_)) {
472     // Special handling of scalars: a scalar pretends to have the same value for
473     // every index, thus effectively implementing broadcasting of its value
474     // over higher-rank arrays.
475     return base_ptr_;
476   }
477   CHECK_EQ(index.size(), shape_.rank());
478   CHECK(index.ShapeIsCompatible(shape_));
479 
480   if (use_linear_index && index.LinearValidOnShape(shape_)) {
481     llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
482     return b->CreateInBoundsGEP(
483         b->CreateBitCast(base_ptr_,
484                          PrimitiveTypeToIrType(shape_.element_type(), module)
485                              ->getPointerTo()),
486         {index.linear()}, llvm_ir::AsStringRef(name));
487   }
488 
489   std::vector<llvm::Value*> actual_index;
490   for (int64_t i = 0; i < index.size(); ++i) {
491     // When dimension i is of size 1, LLVM optimization is able to replace
492     // index[i] with 0. However, setting index[i] to 0 here still allows LLVM to
493     // produce better code in some cases.
494     auto dim = shape_.dimensions(i);
495     actual_index.push_back(
496         dim == 1 ? llvm::ConstantInt::get(index[i]->getType(), 0) : index[i]);
497   }
498 
499   // "base_ptr_" has the type of "<ir_type_for_its_shape>*"
500   // (e.g. [3 x [2 x float]]*). Therefore, the address of the indexed element
501   // should be computed by
502   //
503   //   getelementptr base_ptr_, 0, most major index, ..., most minor index
504   CHECK_GT(index.size(), 0);
505   std::vector<llvm::Value*> gep_indices(
506       1, llvm::ConstantInt::get(index[0]->getType(), 0));
507   for (int64_t i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) {
508     int64_t dimension = LayoutUtil::Major(shape_.layout(), i);
509     gep_indices.push_back(actual_index[dimension]);
510   }
511   return b->CreateInBoundsGEP(base_ptr_, gep_indices,
512                               llvm_ir::AsStringRef(name));
513 }
514 
AnnotateLoadStoreInstructionWithMetadata(llvm::Instruction * instruction) const515 void IrArray::AnnotateLoadStoreInstructionWithMetadata(
516     llvm::Instruction* instruction) const {
517   CHECK(llvm::isa<llvm::LoadInst>(instruction) ||
518         llvm::isa<llvm::StoreInst>(instruction));
519   CHECK(!llvm::isa<llvm::StoreInst>(instruction) || !is_invariant_)
520       << "Trying to create a store to an invariant IRArray.";
521 
522   for (const auto& kind_md_pair : metadata_) {
523     instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
524   }
525 }
526 
EmitReadArrayElement(const Index & index,llvm::IRBuilder<> * b,absl::string_view name,bool use_linear_index) const527 llvm::Value* IrArray::EmitReadArrayElement(const Index& index,
528                                            llvm::IRBuilder<>* b,
529                                            absl::string_view name,
530                                            bool use_linear_index) const {
531   llvm::Value* element_address =
532       EmitArrayElementAddress(index, b, name, use_linear_index);
533   llvm::LoadInst* load = b->CreateLoad(element_address, name.data());
534   AnnotateLoadStoreInstructionWithMetadata(load);
535   return load;
536 }
537 
EmitWriteArrayElement(const Index & index,llvm::Value * value,llvm::IRBuilder<> * b,bool use_linear_index) const538 void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value,
539                                     llvm::IRBuilder<>* b,
540                                     bool use_linear_index) const {
541   llvm::Value* element_address =
542       EmitArrayElementAddress(index, b, "", use_linear_index);
543   llvm::StoreInst* store = b->CreateStore(value, element_address);
544   AnnotateLoadStoreInstructionWithMetadata(store);
545 }
546 
CastToShape(const Shape & new_shape,llvm::IRBuilder<> * b) const547 IrArray IrArray::CastToShape(const Shape& new_shape,
548                              llvm::IRBuilder<>* b) const {
549   llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
550   llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module);
551   IrArray new_irarray(
552       b->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()), new_shape);
553   new_irarray.metadata_ = metadata_;
554   return new_irarray;
555 }
556 
ShapeIsCompatible(const Shape & a,const Shape & b)557 bool IrArray::Index::ShapeIsCompatible(const Shape& a, const Shape& b) {
558   // Compute strides for two sides of the comparison. Sometimes different shapes
559   // give the same strides:
560   //   [10, 20, 30, 1]{3,2,1,0} vs [10, 20, 1, 30]{3,2,1,0}
561   // which should be considered compatible.
562   const auto get_strides = [](const Shape& shape) {
563     int rank = shape.dimensions().size();
564     int64_t stride = 1;
565     std::vector<int64> strides;
566     for (int i = 0; i < rank; i++) {
567       auto dim = shape.dimensions(shape.layout().minor_to_major(i));
568       if (dim != 1) {
569         stride *= dim;
570         strides.push_back(stride);
571       }
572     }
573     return strides;
574   };
575 
576   return get_strides(a) == get_strides(b);
577 }
578 
579 }  // namespace llvm_ir
580 }  // namespace xla
581