• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
17 
18 #include "llvm/IR/Constants.h"
19 #include "llvm/IR/Instructions.h"
20 #include "tensorflow/compiler/xla/layout_util.h"
21 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/compiler/xla/statusor.h"
24 #include "tensorflow/compiler/xla/util.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/types.h"
28 
29 namespace xla {
30 namespace llvm_ir {
31 
Index(absl::Span<llvm::Value * const> multidim,llvm::Value * linear,const Shape & shape,llvm::Type * index_type)32 IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
33                       llvm::Value* linear, const Shape& shape,
34                       llvm::Type* index_type)
35     : Index(multidim, shape, index_type) {
36   CHECK_NE(linear, nullptr);
37   linear_ = linear;
38 }
39 
Delinearize(std::vector<llvm::Value * > * multidim,llvm::Value * linear,const Shape & shape,llvm::IRBuilder<> * b) const40 void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
41                                  llvm::Value* linear, const Shape& shape,
42                                  llvm::IRBuilder<>* b) const {
43   int64 divisor = 1;
44   const Layout& layout = shape.layout();
45   for (int64 i = 0; i < layout.minor_to_major_size(); ++i) {
46     int64 dimension = layout.minor_to_major(i);
47     int64 size_of_current_dimension = shape.dimensions(dimension);
48 
49     // If i is not the last dimension, compute
50     //   (linear_index / divisor) % current_dimension.
51     // If i is the last dimension, we can skip the mod, because we assume that
52     // linear is in bounds.
53     //
54     // TODO(jlebar): We could add bounds checks here and elsewhere in this file,
55     // guarded under some sort of xla-memcheck flag.  This might be particularly
56     // useful because cuda-memcheck can't help us much in XLA: Most of our
57     // memory lives in one big allocation, so cuda-memcheck can't detect
58     // out-of-bounds accesses.
59     auto* quot = b->CreateUDiv(linear, GetConstantWithIndexType(divisor));
60     if (i < layout.minor_to_major_size() - 1) {
61       (*multidim)[dimension] = b->CreateURem(
62           quot, GetConstantWithIndexType(size_of_current_dimension));
63     } else {
64       (*multidim)[dimension] = quot;
65     }
66     divisor *= size_of_current_dimension;
67   }
68 }
69 
Index(llvm::Value * linear,const Shape & shape,llvm::IRBuilder<> * b)70 IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
71                       llvm::IRBuilder<>* b)
72     : multidim_(shape.rank()),
73       linear_(linear),
74       layout_(shape.layout()),
75       dims_(shape.dimensions().begin(), shape.dimensions().end()) {
76   CHECK_NE(linear, nullptr);
77   index_type_ = linear->getType();
78   CHECK(LayoutUtil::HasLayout(shape))
79       << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
80       << " should have a layout.";
81   Delinearize(&multidim_, linear, shape, b);
82 }
83 
Index(absl::Span<llvm::Value * const> multidim,absl::Span<int64 const> dimensions,llvm::Type * index_type)84 IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
85                       absl::Span<int64 const> dimensions,
86                       llvm::Type* index_type)
87     : Index(multidim, ShapeUtil::MakeShape(/*arbitrary*/ PRED, dimensions),
88             index_type) {}
89 
Index(absl::Span<llvm::Value * const> multidim,const Shape & shape,llvm::Type * index_type)90 IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
91                       const Shape& shape, llvm::Type* index_type)
92     : multidim_(multidim.begin(), multidim.end()),
93       linear_(nullptr),
94       layout_(shape.layout()),
95       dims_(shape.dimensions().begin(), shape.dimensions().end()),
96       index_type_(index_type) {
97   CHECK_NE(index_type_, nullptr);
98   CHECK_EQ(shape.dimensions_size(), multidim.size());
99   for (const auto* dim : multidim) {
100     CHECK_NE(dim, nullptr);
101   }
102   CHECK(LayoutUtil::HasLayout(shape))
103       << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
104       << " should have a layout.";
105 }
106 
IrArray(llvm::Value * base_ptr,Shape shape)107 IrArray::IrArray(llvm::Value* base_ptr, Shape shape)
108     : base_ptr_(base_ptr), shape_(std::move(shape)) {
109   TF_CHECK_OK(ShapeUtil::ValidateShape(shape));
110   CHECK(base_ptr_->getType()->isPointerTy());
111   int depth = 0;
112   element_type_ =
113       llvm::cast<llvm::PointerType>(base_ptr_->getType())->getElementType();
114   while (llvm::ArrayType* array_type =
115              llvm::dyn_cast<llvm::ArrayType>(element_type_)) {
116     element_type_ = array_type->getElementType();
117     ++depth;
118   }
119 
120   if (!shape_.IsArray() || ShapeUtil::IsScalar(shape_)) {
121     DCHECK(depth == 1 || depth == 0) << depth;
122   } else {
123     DCHECK_EQ(depth, shape_.rank()) << shape.ShortDebugString();
124   }
125 }
126 
127 // Returns whether the given linear index is valid on the given shape.
LinearValidOnShape(const Shape & a) const128 bool IrArray::Index::LinearValidOnShape(const Shape& a) const {
129   auto b = ShapeUtil::MakeShape(a.element_type(), dims_);
130   *b.mutable_layout() = layout_;
131   return linear_ != nullptr &&
132          ShapeUtil::ElementsIn(a) == ShapeUtil::ElementsIn(b) &&
133          ShapeUtil::ReshapeIsBitcast(a, b);
134 }
135 
SourceIndexOfReshape(const Shape & output_shape,const Shape & input_shape,llvm::IRBuilder<> * builder) const136 IrArray::Index IrArray::Index::SourceIndexOfReshape(
137     const Shape& output_shape, const Shape& input_shape,
138     llvm::IRBuilder<>* builder) const {
139   CHECK_EQ(multidim_.size(), output_shape.rank());
140   const auto common_factors =
141       CommonFactors(AsInt64Slice(input_shape.dimensions()),
142                     AsInt64Slice(output_shape.dimensions()));
143   std::vector<llvm::Value*> source_multidim_index(
144       input_shape.rank(), llvm::UndefValue::get(index_type_));
145   // We compute the source indices in each common factor from only the target
146   // indices in the same common factor.
147   for (ssize_t k = common_factors.size() - 2; k >= 0; --k) {
148     absl::Span<int64 const> dimensions =
149         AsInt64Slice(output_shape.dimensions())
150             .subspan(common_factors[k].second,
151                      common_factors[k + 1].second - common_factors[k].second);
152     llvm::Value* logical_linear_index =
153         Index(absl::Span<llvm::Value* const>(multidim_).subspan(
154                   common_factors[k].second,
155                   common_factors[k + 1].second - common_factors[k].second),
156               dimensions, index_type_)
157             .Linearize(dimensions, builder);
158     // Delinearizes logical_linear_index for the source array in row-major
159     // collapsed order. The first rank-1 indices are the remainder of the
160     // linear index by each dimension size.
161     for (int64 i = common_factors[k + 1].first - 1;
162          i >= common_factors[k].first; --i) {
163       llvm::Value* divisor =
164           GetConstantWithIndexType(input_shape.dimensions(i));
165       if (input_shape.dimensions(i) == 1) {
166         source_multidim_index[i] = GetConstantWithIndexType(0);
167       } else if (i == common_factors[k].first) {
168         source_multidim_index[i] = logical_linear_index;
169       } else {
170         source_multidim_index[i] =
171             builder->CreateURem(logical_linear_index, divisor);
172       }
173       logical_linear_index = builder->CreateUDiv(logical_linear_index, divisor);
174     }
175   }
176 
177   if (linear() != nullptr && LayoutUtil::HasLayout(input_shape) &&
178       LayoutUtil::HasLayout(output_shape) &&
179       ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) {
180     return Index(source_multidim_index, linear(), input_shape, index_type_);
181   }
182   return Index(source_multidim_index, input_shape, index_type_);
183 }
184 
SourceIndexOfSlice(const Shape & operand_shape,absl::Span<const int64> starts,absl::Span<const int64> strides,llvm::IRBuilder<> * builder) const185 IrArray::Index IrArray::Index::SourceIndexOfSlice(
186     const Shape& operand_shape, absl::Span<const int64> starts,
187     absl::Span<const int64> strides, llvm::IRBuilder<>* builder) const {
188   std::vector<llvm::Value*> source_multi_index(multidim_.size());
189   for (int i = 0; i < multidim_.size(); ++i) {
190     int64 stride = strides[i];
191     auto type = multidim_[i]->getType();
192 
193     if (stride != 1) {
194       source_multi_index[i] = builder->CreateAdd(
195           builder->CreateMul(multidim_[i],
196                              llvm::ConstantInt::get(type, stride)),
197           llvm::ConstantInt::get(type, starts[i]));
198     } else {
199       source_multi_index[i] = builder->CreateAdd(
200           multidim_[i], llvm::ConstantInt::get(type, starts[i]));
201     }
202   }
203   return Index(source_multi_index, operand_shape, index_type_);
204 }
205 
SourceIndexOfTranspose(const Shape & shape,const Shape & operand_shape,absl::Span<const int64> dimension_mapping,llvm::IRBuilder<> * builder) const206 IrArray::Index IrArray::Index::SourceIndexOfTranspose(
207     const Shape& shape, const Shape& operand_shape,
208     absl::Span<const int64> dimension_mapping,
209     llvm::IRBuilder<>* builder) const {
210   std::vector<llvm::Value*> operand_multidim_index =
211       Permute(dimension_mapping, multidim());
212 
213   if (linear() != nullptr && LayoutUtil::HasLayout(operand_shape) &&
214       LayoutUtil::HasLayout(shape) &&
215       ShapeUtil::TransposeIsBitcast(operand_shape, shape, dimension_mapping)) {
216     return Index(operand_multidim_index, linear(), operand_shape, index_type_);
217   }
218 
219   return Index(operand_multidim_index, operand_shape, index_type_);
220 }
221 
SourceIndexOfBitcast(const Shape & shape,const Shape & operand_shape,llvm::IRBuilder<> * builder) const222 IrArray::Index IrArray::Index::SourceIndexOfBitcast(
223     const Shape& shape, const Shape& operand_shape,
224     llvm::IRBuilder<>* builder) const {
225   CHECK(LayoutUtil::HasLayout(shape) && LayoutUtil::HasLayout(operand_shape));
226   // In case the bitcast is just a reshape, we can use SourceIndexOfReshape()
227   // instead. This will reuse linear() if possible, so we don't have to build a
228   // new 'linear_index'.
229   if (ShapeUtil::ReshapeIsBitcast(operand_shape, shape)) {
230     return SourceIndexOfReshape(shape, operand_shape, builder);
231   }
232 
233   // First linearize the index coming from the output of the bitcast. We want
234   // the physical index of the element in the buffer. This is like Linearize,
235   // but takes the layout into account.
236   int64 scale = 1;
237   llvm::Value* linear_index = GetConstantWithIndexType(0);
238   for (auto dimension : LayoutUtil::MinorToMajor(shape)) {
239     linear_index = builder->CreateAdd(
240         linear_index,
241         builder->CreateMul(multidim_[dimension],
242                            GetConstantWithIndexType(scale), "",
243                            /*HasNUW=*/true, /*HasNSW=*/true),
244         "", /*HasNUW=*/true, /*HasNSW=*/true);
245     scale *= shape.dimensions(dimension);
246   }
247 
248   return Index(linear_index, operand_shape, builder);
249 }
250 
SourceIndexOfBroadcast(const Shape & shape,const Shape & operand_shape,absl::Span<const int64> dimension_mapping,llvm::IRBuilder<> * builder) const251 IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
252     const Shape& shape, const Shape& operand_shape,
253     absl::Span<const int64> dimension_mapping,
254     llvm::IRBuilder<>* builder) const {
255   int64 rank = operand_shape.rank();
256   std::vector<llvm::Value*> source_index(rank);
257   for (int64 i = 0; i < rank; ++i) {
258     source_index[i] = multidim_[dimension_mapping[i]];
259   }
260   if (linear_ == nullptr || !LayoutUtil::HasLayout(operand_shape) ||
261       !LayoutUtil::HasLayout(shape)) {
262     return Index(source_index, operand_shape, index_type_);
263   }
264   // High-level idea: we can reuse the linear index if the broadcasted
265   // dimensions are contiguous, and this part of the operation is a bitcast.
266   // The other dimensions can be masked out with a div and a mod operation.
267   std::vector<int64> logical_to_physical =
268       LayoutUtil::MakeLogicalToPhysical(shape.layout());
269   int64 output_rank = shape.rank();
270   // The minimum physical dimension that is broadcasted.
271   int64 min_broadcasted_dimension = output_rank;
272   // The maximum physical dimension that is broadcasted.
273   int64 max_broadcasted_dimension = -1;
274   for (int64 i = 0; i < rank; ++i) {
275     int64 physical_dim = logical_to_physical[dimension_mapping[i]];
276     min_broadcasted_dimension =
277         std::min(min_broadcasted_dimension, physical_dim);
278     max_broadcasted_dimension =
279         std::max(max_broadcasted_dimension, physical_dim);
280   }
281   bool contiguous_broadcast_dimensions =
282       max_broadcasted_dimension - min_broadcasted_dimension == rank - 1;
283   if (!contiguous_broadcast_dimensions) {
284     return Index(source_index, operand_shape, index_type_);
285   }
286   // Check if the mapped dimensions are a bitcast.
287   std::vector<int64> operand_logical_to_physical =
288       LayoutUtil::MakeLogicalToPhysical(operand_shape.layout());
289   for (int64 i = 0; i < rank; ++i) {
290     if (operand_logical_to_physical[i] !=
291         logical_to_physical[dimension_mapping[i]] - min_broadcasted_dimension) {
292       return Index(source_index, operand_shape, index_type_);
293     }
294   }
295   llvm::Value* linear = linear_;
296   int64 divisor = 1;
297   for (int64 i = max_broadcasted_dimension + 1; i < output_rank; ++i) {
298     divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
299   }
300   if (divisor > 1) {
301     linear = builder->CreateUDiv(linear, GetConstantWithIndexType(divisor));
302   }
303   if (min_broadcasted_dimension > 0) {
304     int64 mod = 1;
305     for (int64 i = min_broadcasted_dimension; i <= max_broadcasted_dimension;
306          ++i) {
307       mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
308     }
309     linear = builder->CreateURem(linear, GetConstantWithIndexType(mod));
310   }
311   return Index(source_index, linear, operand_shape, index_type_);
312 }
313 
Linearize(absl::Span<const int64> dimensions,llvm::IRBuilder<> * builder) const314 llvm::Value* IrArray::Index::Linearize(absl::Span<const int64> dimensions,
315                                        llvm::IRBuilder<>* builder) const {
316   // Each dimension is multiplied by the product of the sizes of all
317   // earlier dimensions and added to the accumulator logical_linear_index.
318   CHECK_EQ(size(), dimensions.size());
319   llvm::Value* logical_linear_index = GetConstantWithIndexType(0);
320   int64 multiplier = 1;
321   for (ssize_t i = size() - 1; i >= 0; --i) {
322     llvm::Value* addend =
323         builder->CreateMul((*this)[i], GetConstantWithIndexType(multiplier), "",
324                            /*HasNUW=*/true, /*HasNSW=*/true);
325     addend = builder->CreateZExtOrTrunc(addend, index_type_);
326     logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "",
327                                               /*HasNUW=*/true, /*HasNSW=*/true);
328     multiplier *= dimensions[i];
329   }
330   return logical_linear_index;
331 }
332 
EmitArrayElementAddress(const IrArray::Index & index,llvm::IRBuilder<> * b,absl::string_view name,bool use_linear_index) const333 llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
334                                               llvm::IRBuilder<>* b,
335                                               absl::string_view name,
336                                               bool use_linear_index) const {
337   if (ShapeUtil::IsScalar(shape_)) {
338     // Special handling of scalars: a scalar pretends to have the same value for
339     // every index, thus effectively implementing broadcasting of its value
340     // over higher-rank arrays.
341     return base_ptr_;
342   }
343   CHECK_EQ(index.size(), shape_.rank());
344   CHECK(index.ShapeIsCompatible(shape_));
345 
346   if (use_linear_index && index.LinearValidOnShape(shape_)) {
347     llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
348     return b->CreateInBoundsGEP(
349         b->CreateBitCast(base_ptr_,
350                          PrimitiveTypeToIrType(shape_.element_type(), module)
351                              ->getPointerTo()),
352         {index.linear()}, llvm_ir::AsStringRef(name));
353   }
354 
355   std::vector<llvm::Value*> actual_index;
356   for (int64 i = 0; i < index.size(); ++i) {
357     // When dimension i is of size 1, LLVM optimization is able to replace
358     // index[i] with 0. However, setting index[i] to 0 here still allows LLVM to
359     // produce better code in some cases.
360     auto dim = shape_.dimensions(i);
361     actual_index.push_back(
362         dim == 1 ? llvm::ConstantInt::get(index[i]->getType(), 0) : index[i]);
363   }
364 
365   // "base_ptr_" has the type of "<ir_type_for_its_shape>*"
366   // (e.g. [3 x [2 x float]]*). Therefore, the address of the indexed element
367   // should be computed by
368   //
369   //   getelementptr base_ptr_, 0, most major index, ..., most minor index
370   CHECK_GT(index.size(), 0);
371   std::vector<llvm::Value*> gep_indices(
372       1, llvm::ConstantInt::get(index[0]->getType(), 0));
373   for (int64 i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) {
374     int64 dimension = LayoutUtil::Major(shape_.layout(), i);
375     gep_indices.push_back(actual_index[dimension]);
376   }
377   return b->CreateInBoundsGEP(base_ptr_, gep_indices,
378                               llvm_ir::AsStringRef(name));
379 }
380 
AnnotateLoadStoreInstructionWithMetadata(llvm::Instruction * instruction) const381 void IrArray::AnnotateLoadStoreInstructionWithMetadata(
382     llvm::Instruction* instruction) const {
383   CHECK(llvm::isa<llvm::LoadInst>(instruction) ||
384         llvm::isa<llvm::StoreInst>(instruction));
385   CHECK(!llvm::isa<llvm::StoreInst>(instruction) || !is_invariant_)
386       << "Trying to create a store to an invariant IRArray.";
387 
388   for (const auto& kind_md_pair : metadata_) {
389     instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
390   }
391 }
392 
EmitReadArrayElement(const Index & index,llvm::IRBuilder<> * b,absl::string_view name,bool use_linear_index) const393 llvm::Value* IrArray::EmitReadArrayElement(const Index& index,
394                                            llvm::IRBuilder<>* b,
395                                            absl::string_view name,
396                                            bool use_linear_index) const {
397   llvm::Value* element_address =
398       EmitArrayElementAddress(index, b, name, use_linear_index);
399   llvm::LoadInst* load = b->CreateLoad(element_address);
400   AnnotateLoadStoreInstructionWithMetadata(load);
401   return load;
402 }
403 
EmitWriteArrayElement(const Index & index,llvm::Value * value,llvm::IRBuilder<> * b,bool use_linear_index) const404 void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value,
405                                     llvm::IRBuilder<>* b,
406                                     bool use_linear_index) const {
407   llvm::Value* element_address =
408       EmitArrayElementAddress(index, b, "", use_linear_index);
409   llvm::StoreInst* store = b->CreateStore(value, element_address);
410   AnnotateLoadStoreInstructionWithMetadata(store);
411 }
412 
CastToShape(const Shape & new_shape,llvm::IRBuilder<> * b) const413 IrArray IrArray::CastToShape(const Shape& new_shape,
414                              llvm::IRBuilder<>* b) const {
415   llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
416   llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module);
417   IrArray new_irarray(
418       b->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()), new_shape);
419   new_irarray.metadata_ = metadata_;
420   return new_irarray;
421 }
422 
423 }  // namespace llvm_ir
424 }  // namespace xla
425