/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/kernel_tiling.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {
namespace llvm_ir {

namespace {
// Returns the indices of the first elements of all consecutive subarrays of
// the given array. For example:
// ConsecutiveSegments({m, m+1, m+2, n, k, k+1}) = {0, 3, 4}
std::vector<size_t> ConsecutiveSegments(absl::Span<const int64> xs) {
  std::vector<size_t> is = {0};
  for (size_t i = 1; i < xs.size(); ++i) {
    if (1 != xs[i] - xs[i - 1]) {
      is.push_back(i);
    }
  }
  return is;
}

// Merges the sequences of dimensions of the given shape which start at the
// given indices `segs`.
Shape MergeDimensions(absl::Span<const size_t> segs, const Shape& shape) {
  std::vector<int64> dimensions;
  for (size_t i = 1; i <= segs.size(); ++i) {
    dimensions.push_back(std::accumulate(
        shape.dimensions().begin() + segs[i - 1],
        shape.dimensions().begin() +
            (segs.size() == i ? shape.dimensions().size() : segs[i]),
        1, std::multiplies<int64>()));
  }
  return ShapeUtil::MakeShapeWithDescendingLayout(shape.element_type(),
                                                  dimensions);
}

// Given an index for a shape, return the equivalent new index if the shape is
// reshaped to another shape.
IrArray::Index GetReshapedIndex(const IrArray::Index& index, const Shape& shape,
                                const Shape& reshaped_shape,
                                llvm::IRBuilder<>* b) {
  auto bounds = shape.dimensions();
  auto minor_to_major = shape.layout().minor_to_major();
  llvm::Value* linear_index = index.GetConstantWithIndexType(0);
  int64 multiplier = 1;
  for (int i = 0; i < index.size(); ++i) {
    int64 dim = minor_to_major[i];
    llvm::Value* addend = b->CreateMul(
        index[dim], index.GetConstantWithIndexType(multiplier), "linearizing",
        /*HasNUW=*/true, /*HasNSW=*/true);
    linear_index = b->CreateAdd(linear_index, addend, "",
                                /*HasNUW=*/true, /*HasNSW=*/true);
    multiplier *= bounds[dim];
  }

  return IrArray::Index(linear_index, reshaped_shape, b);
}

}  // namespace

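// If the physical layout of `b` is a 0-2-1 transpose of the physical layout
// of `a` (after collapsing runs of dimensions that stay contiguous across the
// transpose), returns the three collapsed dimensions; otherwise returns
// absl::nullopt.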
absl::optional<std::vector<int64> > FindTranspose021(const Shape& a,
                                                     const Shape& b) {
  if (!ShapeUtil::CompatibleIgnoringElementType(a, b)) {
    return absl::nullopt;
  }

  std::vector<int64> permutation(a.dimensions().size());
  absl::Span<const int64> minor_to_major_a = LayoutUtil::MinorToMajor(a);
  std::vector<int64> major_to_minor_a(minor_to_major_a.rbegin(),
                                      minor_to_major_a.rend());
  absl::Span<const int64> minor_to_major_b = LayoutUtil::MinorToMajor(b);
  std::vector<int64> major_to_minor_b(minor_to_major_b.rbegin(),
                                      minor_to_major_b.rend());
  for (size_t i = 0; i < permutation.size(); ++i) {
    permutation[i] = PositionInContainer(major_to_minor_b, major_to_minor_a[i]);
  }

  std::vector<size_t> segments = ConsecutiveSegments(permutation);
  if ((3 == segments.size() && 0 == permutation[0]) || 2 == segments.size()) {
    Shape descending_layout_shape =
        ShapeUtil::MakeShapeWithDescendingLayoutAndSamePhysicalLayout(a);
    Shape normalized_shape = MergeDimensions(segments, descending_layout_shape);
    absl::Span<const int64> normalized_dims =
        AsInt64Slice(normalized_shape.dimensions());
    std::vector<int64> dims_021;
    if (2 == segments.size()) {
      // The logical component-0 is of size one.
      dims_021 = {1, normalized_dims[1], normalized_dims[0]};
    } else {
      dims_021 = {normalized_dims[0], normalized_dims[2], normalized_dims[1]};
    }

    return dims_021;
  }

  return absl::nullopt;
}

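// Builds the mapping from the normalized shape (given by `dims_in_elems`, in
// elements) to tiles and blocks: tile sizes are {1, tile_size_y, tile_size_x},
// block sizes are the requested sizes clamped to the number of tiles in each
// dimension, and each tile is processed by num_threads_y x num_threads_x
// threads.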
KernelMappingScheme::KernelMappingScheme(
    absl::Span<const int64> dims_in_elems, int64 tile_size_y, int64 tile_size_x,
    absl::Span<const int64> req_block_sizes, int64 num_threads_y,
    int64 num_threads_x, llvm::IRBuilder<>* b)
    : b_(b),
      dims_in_elems_(dims_in_elems.begin(), dims_in_elems.end()),
      tile_sizes_{1, tile_size_y, tile_size_x},
      num_threads_x_(num_threads_x),
      num_threads_y_(num_threads_y),
      dilated_x_(true) {
  DCHECK_EQ(dims_in_elems_.size(), 3);
  DCHECK_EQ(req_block_sizes.size(), 3);

  DCHECK_EQ(tile_size_y % num_threads_y_, 0);
  DCHECK_EQ(tile_size_x % num_threads_x_, 0);

  dims_in_tiles_ = ElementWiseCeilOfRatio<int64>(dims_in_elems_, tile_sizes_);
  block_sizes_.reserve(req_block_sizes.size());
  absl::c_transform(req_block_sizes, dims_in_tiles_,
                    std::back_inserter(block_sizes_),
                    [](const int64 requested_size, const int64 max_size) {
                      return std::min(requested_size, max_size);
                    });
  dims_in_blocks_ = ElementWiseCeilOfRatio<int64>(dims_in_tiles_, block_sizes_);

  VLOG(10) << "dims_in_elems_ = [" << absl::StrJoin(dims_in_elems_, ",") << "]";
  VLOG(10) << "dims_in_tiles_ = [" << absl::StrJoin(dims_in_tiles_, ",") << "]";
  VLOG(10) << "dims_in_blocks_ = [" << absl::StrJoin(dims_in_blocks_, ",")
           << "]";
}

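// Translates an index into the normalized (descending-layout) shape into the
// equivalent index into `unnormalized_shape`.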
IrArray::Index KernelMappingScheme::GetUnnormalizedIndex(
    const IrArray::Index& normalized_shape_index,
    const Shape& unnormalized_shape) {
  DCHECK_EQ(normalized_shape_index.size(), dims_in_elems_.size());
  Shape output_shape = ShapeUtil::MakeShapeWithDescendingLayout(
      unnormalized_shape.element_type(), GetDimensionsInElements());
  return GetReshapedIndex(normalized_shape_index, output_shape,
                          unnormalized_shape, b_);
}

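// Reads the linear block id (ctaid.x) and converts it into a multidimensional
// index into the grid of blocks described by dims_in_blocks_.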
IrArray::Index KernelMappingScheme::EmitBlockIndex(llvm::Type* index_ty) {
  llvm::Value* block_id = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, {}, {}, b_);
  llvm_ir::AddRangeMetadata(0, GetNumberOfBlocks(),
                            llvm::cast<llvm::Instruction>(block_id));
  llvm::Value* linear_block_id =
      b_->CreateIntCast(block_id, index_ty, /*isSigned=*/true, "block.id.x");
  return IrArray::Index(linear_block_id,
                        ShapeUtil::MakeShapeWithDescendingLayout(
                            PRED /*arbitrary*/, dims_in_blocks_),
                        b_);
}

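// Returns the index of the first tile of the block given by `block_index`.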
IrArray::Index KernelMappingScheme::GetTileIndexForBlockOrigin(
    const IrArray::Index& block_index) {
  DCHECK_EQ(block_index.size(), block_sizes_.size());
  std::vector<llvm::Value*> multidim;
  multidim.reserve(block_sizes_.size());
  for (int i = 0; i < block_sizes_.size(); ++i) {
    multidim.push_back(b_->CreateMul(
        block_index[i],
        llvm::ConstantInt::get(block_index[i]->getType(), block_sizes_[i]),
        "block_origin." + std::to_string(i)));
  }
  return IrArray::Index(multidim, block_index[0]->getType());
}

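// Returns the index of the first element of the tile given by `tile_index`.
// Only the Y and X components are scaled, since the tile size along the
// remaining dimension is 1.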
IrArray::Index KernelMappingScheme::GetElementIndexForTileOrigin(
    const IrArray::Index& tile_index) {
  std::vector<llvm::Value*> elem_multi_index = tile_index.multidim();
  for (int i = DimY; i < DimTot; ++i) {
    elem_multi_index[i] =
        b_->CreateMul(tile_index[i],
                      llvm::ConstantInt::get(tile_index[i]->getType(),
                                             GetTileSizeForDimension(i)),
                      "tile_origin." + std::to_string(i));
  }
  return IrArray::Index(elem_multi_index, tile_index.GetType());
}

llvm::GlobalVariable* KernelMappingScheme::GetSharedMemoryBufferForElementType(
    llvm::Type* elem_ty, absl::string_view buffer_name) {
  // If shared memory transpose is needed, we use square tiles.
  CHECK_EQ(GetTileSizeForDimensionX(), GetTileSizeForDimensionY());

  // For Nvidia GPUs, the warp size is 32 threads and shared memory is
  // organized into 32 banks. We usually use the warp size, or a multiple of
  // the warp size, as the tile size. This can make all elements in the same
  // column of a tile map to the same memory bank and therefore cause shared
  // memory bank conflicts. Adding 1 to the minor dimension of the shared
  // memory buffer reduces such bank conflicts.
  llvm::Type* buffer_type = llvm::ArrayType::get(
      llvm::ArrayType::get(elem_ty, GetTileSizeForDimension(DimX) + 1),
      GetTileSizeForDimension(DimY));
  return llvm_ir::AllocateSharedMemoryTile(b_->GetInsertBlock()->getModule(),
                                           buffer_type, buffer_name);
}

std::tuple<llvm::Value*, llvm::Value*>
KernelMappingScheme::EmitThreadYXCoordinate(llvm::Type* index_ty) {
  // Calculate the (y, x) coordinates of the thread in the 2D view of the
  // thread block defined by (num_thread_y, num_thread_x) from thread_id.
  llvm::CallInst* thread_id_raw = llvm_ir::EmitCallToIntrinsic(
      llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, {}, {}, b_);
  llvm_ir::AddRangeMetadata(0, GetThreadsPerBlock(), thread_id_raw);
  llvm::Value* thread_id_int =
      b_->CreateIntCast(thread_id_raw, index_ty,
                        /*isSigned=*/true, "thread.id.x");
  llvm::Value* num_thread_x =
      llvm::ConstantInt::get(index_ty, GetNumberOfThreadsForDimensionX());
  llvm::Value* x = b_->CreateURem(thread_id_int, num_thread_x, "thread.x");
  llvm::Value* y = b_->CreateUDiv(thread_id_int, num_thread_x, "thread.y");
  return std::make_tuple(y, x);
}

}  // namespace llvm_ir
}  // namespace xla