1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
17
18 #include "llvm/IR/Constants.h"
19 #include "llvm/IR/Instructions.h"
20 #include "tensorflow/compiler/xla/layout_util.h"
21 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
22 #include "tensorflow/compiler/xla/shape_util.h"
23 #include "tensorflow/compiler/xla/statusor.h"
24 #include "tensorflow/compiler/xla/util.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26 #include "tensorflow/core/platform/logging.h"
27 #include "tensorflow/core/platform/types.h"
28
29 namespace xla {
30 namespace llvm_ir {
31
// Constructs an index that carries both multidimensional coordinates and a
// precomputed linearization of them. Delegates to the multidim-only
// constructor for all shape/coordinate validation, then records the
// (required non-null) linear value.
IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
                      llvm::Value* linear, const Shape& shape,
                      llvm::Type* index_type)
    : Index(multidim, shape, index_type) {
  CHECK_NE(linear, nullptr);
  linear_ = linear;
}
39
Delinearize(std::vector<llvm::Value * > * multidim,llvm::Value * linear,const Shape & shape,llvm::IRBuilder<> * b) const40 void IrArray::Index::Delinearize(std::vector<llvm::Value*>* multidim,
41 llvm::Value* linear, const Shape& shape,
42 llvm::IRBuilder<>* b) const {
43 int64 divisor = 1;
44 const Layout& layout = shape.layout();
45 for (int64 i = 0; i < layout.minor_to_major_size(); ++i) {
46 int64 dimension = layout.minor_to_major(i);
47 int64 size_of_current_dimension = shape.dimensions(dimension);
48
49 // If i is not the last dimension, compute
50 // (linear_index / divisor) % current_dimension.
51 // If i is the last dimension, we can skip the mod, because we assume that
52 // linear is in bounds.
53 //
54 // TODO(jlebar): We could add bounds checks here and elsewhere in this file,
55 // guarded under some sort of xla-memcheck flag. This might be particularly
56 // useful because cuda-memcheck can't help us much in XLA: Most of our
57 // memory lives in one big allocation, so cuda-memcheck can't detect
58 // out-of-bounds accesses.
59 auto* quot = b->CreateUDiv(linear, GetConstantWithIndexType(divisor));
60 if (i < layout.minor_to_major_size() - 1) {
61 (*multidim)[dimension] = b->CreateURem(
62 quot, GetConstantWithIndexType(size_of_current_dimension));
63 } else {
64 (*multidim)[dimension] = quot;
65 }
66 divisor *= size_of_current_dimension;
67 }
68 }
69
// Constructs an index from a linear index only; the per-dimension
// coordinates are derived from `linear` via Delinearize, which requires
// `shape` to carry a layout.
IrArray::Index::Index(llvm::Value* linear, const Shape& shape,
                      llvm::IRBuilder<>* b)
    : multidim_(shape.rank()),
      linear_(linear),
      layout_(shape.layout()),
      dims_(shape.dimensions().begin(), shape.dimensions().end()) {
  CHECK_NE(linear, nullptr);
  // The index type is taken from the LLVM type of the linear value.
  index_type_ = linear->getType();
  CHECK(LayoutUtil::HasLayout(shape))
      << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
      << " should have a layout.";
  Delinearize(&multidim_, linear, shape, b);
}
83
Index(absl::Span<llvm::Value * const> multidim,const Shape & shape,llvm::Type * index_type)84 IrArray::Index::Index(absl::Span<llvm::Value* const> multidim,
85 const Shape& shape, llvm::Type* index_type)
86 : multidim_(multidim.begin(), multidim.end()),
87 linear_(nullptr),
88 layout_(shape.layout()),
89 dims_(shape.dimensions().begin(), shape.dimensions().end()),
90 index_type_(index_type) {
91 CHECK_NE(index_type_, nullptr);
92 CHECK_EQ(shape.dimensions_size(), multidim.size());
93 for (const auto* dim : multidim) {
94 CHECK_NE(dim, nullptr);
95 }
96 CHECK(LayoutUtil::HasLayout(shape))
97 << "Shape " << ShapeUtil::HumanStringWithLayout(shape)
98 << " should have a layout.";
99 }
100
IrArray(llvm::Value * base_ptr,Shape shape)101 IrArray::IrArray(llvm::Value* base_ptr, Shape shape)
102 : base_ptr_(base_ptr), shape_(std::move(shape)) {
103 TF_CHECK_OK(ShapeUtil::ValidateShape(shape));
104 CHECK(base_ptr_->getType()->isPointerTy());
105 int depth = 0;
106 element_type_ =
107 llvm::cast<llvm::PointerType>(base_ptr_->getType())->getElementType();
108 while (llvm::ArrayType* array_type =
109 llvm::dyn_cast<llvm::ArrayType>(element_type_)) {
110 element_type_ = array_type->getElementType();
111 ++depth;
112 }
113
114 if (!shape_.IsArray() || ShapeUtil::IsScalar(shape_)) {
115 DCHECK(depth == 1 || depth == 0) << depth;
116 } else {
117 DCHECK_EQ(depth, shape_.rank()) << shape.ShortDebugString();
118 }
119 }
120
121 // Returns whether the given linear index is valid on the given shape.
LinearValidOnShape(const Shape & a) const122 bool IrArray::Index::LinearValidOnShape(const Shape& a) const {
123 auto b = ShapeUtil::MakeShape(a.element_type(), dims_);
124 *b.mutable_layout() = layout_;
125 return linear_ != nullptr &&
126 ShapeUtil::ElementsIn(a) == ShapeUtil::ElementsIn(b) &&
127 ShapeUtil::ReshapeIsBitcast(a, b);
128 }
129
// Maps this index (into `output_shape`) back to the corresponding index into
// `input_shape` for a reshape. The two shapes are decomposed into "common
// factors" — maximal groups of dimensions whose element counts match
// pairwise — so the index can be translated group by group.
IrArray::Index IrArray::Index::SourceIndexOfReshape(
    const Shape& output_shape, const Shape& input_shape,
    llvm::IRBuilder<>* builder) const {
  const auto& target_index = *this;
  CHECK_EQ(target_index.size(), output_shape.rank());
  // common_factors holds (input_dim, output_dim) boundary pairs; entries k
  // and k+1 delimit one group of dimensions in each shape.
  std::vector<std::pair<int64, int64>> common_factors =
      CommonFactors(AsInt64Slice(input_shape.dimensions()),
                    AsInt64Slice(output_shape.dimensions()));
  std::vector<llvm::Value*> source_multidim_index(
      input_shape.rank(), llvm::UndefValue::get(index_type_));
  // We compute the source indices in each common factor from only the target
  // indices in the same common factor.
  for (ssize_t k = common_factors.size() - 2; k >= 0; --k) {
    // Linearize the output coordinates belonging to group k, in row-major
    // order over that group's dimensions.
    llvm::Value* logical_linear_index =
        Index(absl::Span<llvm::Value* const>(multidim_).subspan(
                  common_factors[k].second,
                  common_factors[k + 1].second - common_factors[k].second),
              index_type_)
            .Linearize(AsInt64Slice(output_shape.dimensions())
                           .subspan(common_factors[k].second,
                                    common_factors[k + 1].second -
                                        common_factors[k].second),
                       builder);
    // Delinearizes logical_linear_index for the source array in row-major
    // collapsed order. The first rank-1 indices are the remainder of the
    // linear index by each dimension size.
    for (int64 i = common_factors[k + 1].first - 1;
         i >= common_factors[k].first; --i) {
      llvm::Value* divisor =
          GetConstantWithIndexType(input_shape.dimensions(i));
      if (input_shape.dimensions(i) == 1) {
        // Degenerate dimension: the only valid coordinate is 0.
        source_multidim_index[i] = GetConstantWithIndexType(0);
      } else if (i == common_factors[k].first) {
        // Most-major dimension of the group: no remainder needed.
        source_multidim_index[i] = logical_linear_index;
      } else {
        source_multidim_index[i] =
            builder->CreateURem(logical_linear_index, divisor);
      }
      logical_linear_index = builder->CreateUDiv(logical_linear_index, divisor);
    }
  }

  // If the reshape is a bitcast (no data movement between the two layouts),
  // the existing linear index can be carried over to the source index.
  if (linear() != nullptr && LayoutUtil::HasLayout(input_shape) &&
      LayoutUtil::HasLayout(output_shape) &&
      ShapeUtil::ReshapeIsBitcast(input_shape, output_shape)) {
    return Index(source_multidim_index, linear(), input_shape, index_type_);
  }
  return Index(source_multidim_index, index_type_);
}
179
SourceIndexOfSlice(const Shape & operand_shape,absl::Span<const int64> starts,absl::Span<const int64> strides,llvm::IRBuilder<> * builder) const180 IrArray::Index IrArray::Index::SourceIndexOfSlice(
181 const Shape& operand_shape, absl::Span<const int64> starts,
182 absl::Span<const int64> strides, llvm::IRBuilder<>* builder) const {
183 std::vector<llvm::Value*> source_multi_index(multidim_.size());
184 for (int i = 0; i < multidim_.size(); ++i) {
185 int64 stride = strides[i];
186 auto type = multidim_[i]->getType();
187
188 if (stride != 1) {
189 source_multi_index[i] = builder->CreateAdd(
190 builder->CreateMul(multidim_[i],
191 llvm::ConstantInt::get(type, stride)),
192 llvm::ConstantInt::get(type, starts[i]));
193 } else {
194 source_multi_index[i] = builder->CreateAdd(
195 multidim_[i], llvm::ConstantInt::get(type, starts[i]));
196 }
197 }
198 return Index(source_multi_index, operand_shape, index_type_);
199 }
200
SourceIndexOfTranspose(const Shape & shape,const Shape & operand_shape,absl::Span<const int64> dimension_mapping,llvm::IRBuilder<> * builder) const201 IrArray::Index IrArray::Index::SourceIndexOfTranspose(
202 const Shape& shape, const Shape& operand_shape,
203 absl::Span<const int64> dimension_mapping,
204 llvm::IRBuilder<>* builder) const {
205 std::vector<llvm::Value*> operand_multidim_index =
206 Permute(dimension_mapping, multidim());
207
208 if (linear() != nullptr && LayoutUtil::HasLayout(operand_shape) &&
209 LayoutUtil::HasLayout(shape) &&
210 ShapeUtil::TransposeIsBitcast(operand_shape, shape, dimension_mapping)) {
211 return Index(operand_multidim_index, linear(), operand_shape, index_type_);
212 }
213
214 return Index(operand_multidim_index);
215 }
216
// Maps this index (into `shape`, the bitcast's output) back to an index into
// `operand_shape` (the bitcast's input). Both shapes must carry layouts,
// since a bitcast is defined in terms of physical element order.
IrArray::Index IrArray::Index::SourceIndexOfBitcast(
    const Shape& shape, const Shape& operand_shape,
    llvm::IRBuilder<>* builder) const {
  CHECK(LayoutUtil::HasLayout(shape) && LayoutUtil::HasLayout(operand_shape));
  // In case the bitcast is just a reshape, we can use SourceIndexOfReshape()
  // instead. This will reuse linear() if possible, so we don't have to build a
  // new 'linear_index'.
  if (ShapeUtil::ReshapeIsBitcast(operand_shape, shape)) {
    return SourceIndexOfReshape(shape, operand_shape, builder);
  }

  // First linearize the index coming from the output of the bitcast. We want
  // the physical index of the element in the buffer. This is like Linearize,
  // but takes the layout into account.
  int64 scale = 1;
  llvm::Value* linear_index = GetConstantWithIndexType(0);
  for (auto dimension : LayoutUtil::MinorToMajor(shape)) {
    // Accumulate multidim_[dimension] * scale, walking dimensions from
    // most-minor to most-major. NUW/NSW assert no overflow, since indices
    // are assumed in bounds.
    linear_index = builder->CreateAdd(
        linear_index,
        builder->CreateMul(multidim_[dimension],
                           GetConstantWithIndexType(scale), "",
                           /*HasNUW=*/true, /*HasNSW=*/true),
        "", /*HasNUW=*/true, /*HasNSW=*/true);
    scale *= shape.dimensions(dimension);
  }

  // Now delinearize it for the input of the bitcast.
  std::vector<llvm::Value*> multi_index(operand_shape.dimensions_size());
  Delinearize(&multi_index, linear_index, operand_shape, builder);

  return Index(multi_index, linear_index, operand_shape, index_type_);
}
249
// Maps this index (into the broadcast output `shape`) back to an index into
// `operand_shape`. dimension_mapping[i] gives the output dimension that
// operand dimension i maps to.
IrArray::Index IrArray::Index::SourceIndexOfBroadcast(
    const Shape& shape, const Shape& operand_shape,
    absl::Span<const int64> dimension_mapping,
    llvm::IRBuilder<>* builder) const {
  int64 rank = operand_shape.rank();
  std::vector<llvm::Value*> source_index(rank);
  for (int64 i = 0; i < rank; ++i) {
    source_index[i] = multidim_[dimension_mapping[i]];
  }
  // Without a linear form, or without layouts on both shapes, only the
  // multidimensional coordinates can be produced.
  if (linear_ == nullptr || !LayoutUtil::HasLayout(operand_shape) ||
      !LayoutUtil::HasLayout(shape)) {
    return Index(source_index, index_type_);
  }
  // High-level idea: we can reuse the linear index if the broadcasted
  // dimensions are contiguous, and this part of the operation is a bitcast.
  // The other dimensions can be masked out with a div and a mod operation.
  std::vector<int64> logical_to_physical =
      LayoutUtil::MakeLogicalToPhysical(shape.layout());
  int64 output_rank = shape.rank();
  // The minimum physical dimension that is broadcasted.
  int64 min_broadcasted_dimension = output_rank;
  // The maximum physical dimension that is broadcasted.
  int64 max_broadcasted_dimension = -1;
  for (int64 i = 0; i < rank; ++i) {
    int64 physical_dim = logical_to_physical[dimension_mapping[i]];
    min_broadcasted_dimension =
        std::min(min_broadcasted_dimension, physical_dim);
    max_broadcasted_dimension =
        std::max(max_broadcasted_dimension, physical_dim);
  }
  bool contiguous_broadcast_dimensions =
      max_broadcasted_dimension - min_broadcasted_dimension == rank - 1;
  if (!contiguous_broadcast_dimensions) {
    return Index(source_index, index_type_);
  }
  // Check if the mapped dimensions are a bitcast.
  std::vector<int64> operand_logical_to_physical =
      LayoutUtil::MakeLogicalToPhysical(operand_shape.layout());
  for (int64 i = 0; i < rank; ++i) {
    if (operand_logical_to_physical[i] !=
        logical_to_physical[dimension_mapping[i]] - min_broadcasted_dimension) {
      return Index(source_index, index_type_);
    }
  }
  llvm::Value* linear = linear_;
  int64 divisor = 1;
  // Divide away the physical dimensions more minor than the broadcasted
  // block...
  for (int64 i = max_broadcasted_dimension + 1; i < output_rank; ++i) {
    divisor *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
  }
  if (divisor > 1) {
    linear = builder->CreateUDiv(linear, GetConstantWithIndexType(divisor));
  }
  // ...then mod away the dimensions more major than it, leaving the linear
  // index within the operand's block.
  if (min_broadcasted_dimension > 0) {
    int64 mod = 1;
    for (int64 i = min_broadcasted_dimension; i <= max_broadcasted_dimension;
         ++i) {
      mod *= shape.dimensions(LayoutUtil::Major(shape.layout(), i));
    }
    linear = builder->CreateURem(linear, GetConstantWithIndexType(mod));
  }
  return Index(source_index, linear, operand_shape, index_type_);
}
312
Linearize(absl::Span<const int64> dimensions,llvm::IRBuilder<> * builder) const313 llvm::Value* IrArray::Index::Linearize(absl::Span<const int64> dimensions,
314 llvm::IRBuilder<>* builder) const {
315 // Each dimension is multiplied by the product of the sizes of all
316 // earlier dimensions and added to the accumulator logical_linear_index.
317 CHECK_EQ(size(), dimensions.size());
318 llvm::Value* logical_linear_index = GetConstantWithIndexType(0);
319 int64 multiplier = 1;
320 for (ssize_t i = size() - 1; i >= 0; --i) {
321 llvm::Value* addend =
322 builder->CreateMul((*this)[i], GetConstantWithIndexType(multiplier), "",
323 /*HasNUW=*/true, /*HasNSW=*/true);
324 addend = builder->CreateZExtOrTrunc(addend, index_type_);
325 logical_linear_index = builder->CreateAdd(logical_linear_index, addend, "",
326 /*HasNUW=*/true, /*HasNSW=*/true);
327 multiplier *= dimensions[i];
328 }
329 return logical_linear_index;
330 }
331
// Returns the address of the element at `index`. If the index's linear form
// is valid for this array's shape (and use_linear_index is set), a flat GEP
// off a bitcast element pointer is emitted; otherwise a multidimensional
// in-bounds GEP is built, ordered from most-major to most-minor dimension.
llvm::Value* IrArray::EmitArrayElementAddress(const IrArray::Index& index,
                                              llvm::IRBuilder<>* b,
                                              absl::string_view name,
                                              bool use_linear_index) const {
  if (ShapeUtil::IsScalar(shape_)) {
    // Special handling of scalars: a scalar pretends to have the same value for
    // every index, thus effectively implementing broadcasting of its value
    // over higher-rank arrays.
    return base_ptr_;
  }
  CHECK_EQ(index.size(), shape_.rank());

  if (use_linear_index && index.LinearValidOnShape(shape_)) {
    llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
    return b->CreateInBoundsGEP(
        b->CreateBitCast(base_ptr_,
                         PrimitiveTypeToIrType(shape_.element_type(), module)
                             ->getPointerTo()),
        {index.linear()}, llvm_ir::AsStringRef(name));
  }

  std::vector<llvm::Value*> actual_index;
  for (int64 i = 0; i < index.size(); ++i) {
    // When dimension i is of size 1, LLVM optimization is able to replace
    // index[i] with 0. However, setting index[i] to 0 here still allows LLVM to
    // produce better code in some cases.
    auto dim = shape_.dimensions(i);
    actual_index.push_back(
        dim == 1 ? llvm::ConstantInt::get(index[i]->getType(), 0) : index[i]);
  }

  // "base_ptr_" has the type of "<ir_type_for_its_shape>*"
  // (e.g. [3 x [2 x float]]*). Therefore, the address of the indexed element
  // should be computed by
  //
  //   getelementptr base_ptr_, 0, most major index, ..., most minor index
  CHECK_GT(index.size(), 0);
  std::vector<llvm::Value*> gep_indices(
      1, llvm::ConstantInt::get(index[0]->getType(), 0));
  // Walk physical dimensions from most-major to most-minor, as required by
  // the nested-array GEP above.
  for (int64 i = 0; i < LayoutUtil::MinorToMajor(shape_).size(); ++i) {
    int64 dimension = LayoutUtil::Major(shape_.layout(), i);
    gep_indices.push_back(actual_index[dimension]);
  }
  return b->CreateInBoundsGEP(base_ptr_, gep_indices,
                              llvm_ir::AsStringRef(name));
}
378
AnnotateLoadStoreInstructionWithMetadata(llvm::Instruction * instruction) const379 void IrArray::AnnotateLoadStoreInstructionWithMetadata(
380 llvm::Instruction* instruction) const {
381 CHECK(llvm::isa<llvm::LoadInst>(instruction) ||
382 llvm::isa<llvm::StoreInst>(instruction));
383 CHECK(!llvm::isa<llvm::StoreInst>(instruction) || !is_invariant_)
384 << "Trying to create a store to an invariant IRArray.";
385
386 for (const auto& kind_md_pair : metadata_) {
387 instruction->setMetadata(kind_md_pair.first, kind_md_pair.second);
388 }
389 }
390
EmitReadArrayElement(const Index & index,llvm::IRBuilder<> * b,absl::string_view name,bool use_linear_index) const391 llvm::Value* IrArray::EmitReadArrayElement(const Index& index,
392 llvm::IRBuilder<>* b,
393 absl::string_view name,
394 bool use_linear_index) const {
395 llvm::Value* element_address =
396 EmitArrayElementAddress(index, b, name, use_linear_index);
397 llvm::LoadInst* load = b->CreateLoad(element_address);
398 AnnotateLoadStoreInstructionWithMetadata(load);
399 return load;
400 }
401
EmitWriteArrayElement(const Index & index,llvm::Value * value,llvm::IRBuilder<> * b,bool use_linear_index) const402 void IrArray::EmitWriteArrayElement(const Index& index, llvm::Value* value,
403 llvm::IRBuilder<>* b,
404 bool use_linear_index) const {
405 llvm::Value* element_address =
406 EmitArrayElementAddress(index, b, "", use_linear_index);
407 llvm::StoreInst* store = b->CreateStore(value, element_address);
408 AnnotateLoadStoreInstructionWithMetadata(store);
409 }
410
CastToShape(const Shape & new_shape,llvm::IRBuilder<> * b) const411 IrArray IrArray::CastToShape(const Shape& new_shape,
412 llvm::IRBuilder<>* b) const {
413 llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
414 llvm::Type* new_ir_type = llvm_ir::ShapeToIrType(new_shape, module);
415 IrArray new_irarray(
416 b->CreatePointerCast(base_ptr_, new_ir_type->getPointerTo()), new_shape);
417 new_irarray.metadata_ = metadata_;
418 return new_irarray;
419 }
420
421 } // namespace llvm_ir
422 } // namespace xla
423