/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"

#include <vector>

// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ADT/APInt.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Value.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace xla {
namespace llvm_ir {

namespace {

// Adds the inner comparison loop body where we compare elements.
Status EmitCompareLoopBody(
    int64 iteration_bound, int64 num_values, llvm::Value* element_pair_index,
    int64 xor_mask, llvm::Type* index_type,
    std::function<llvm::Value*(int64 operand, llvm::Value* index)>
        element_address,
    std::function<void(int64 operand, llvm::Value* index, llvm::Value* value)>
        write_element,
    const EmitCallToNestedComputationCallback& emit_compare_callback,
    llvm::IRBuilder<>* b, bool needs_bounds_checks = true) {
  auto index_typed_constant = [&](int64 value) {
    return llvm::ConstantInt::get(index_type, value);
  };
  // The 'xor_mask' determines which elements are compared against each other.
  // Index 'current_keys_index' will be compared with 'current_keys_index' xor
  // 'xor_mask'. This means that we will always compare a block of consecutive
  // elements against elements from the adjacent block of the same size. When
  // 'xor_mask' is a power of 2, it immediately identifies the size of such a
  // block. We can also have 'xor_mask' being 2^k - 1 (for some value of k). In
  // that case, we essentially flip the last 'k' bits when computing the
  // position of the element to compare to, so the block size is 2^(k - 1).
  int64 block_size = xor_mask;
  // Check if it is a value 2^k - 1.
  if (xor_mask > 1 && (xor_mask & (xor_mask + 1)) == 0) {
    block_size = (xor_mask + 1) / 2;
  }
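  // Two concrete cases, with numbers chosen purely for illustration: for
  // xor_mask = 4 (a power of 2) the block size stays 4, and indices 0..3 are
  // compared with 4..7 (0<->4, 1<->5, 2<->6, 3<->7). For xor_mask = 7 =
  // 2^3 - 1, block_size becomes (7 + 1) / 2 = 4 and the pairing is reversed
  // across the two blocks: 0<->7, 1<->6, 2<->5, 3<->4.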
  auto current_keys_index = element_pair_index;
  if (block_size == 1) {
    // If the block size is 1, we take every second element and compare it to
    // the next one.
    current_keys_index =
        b->CreateMul(current_keys_index, index_typed_constant(2));
  } else if (block_size * 2 < iteration_bound) {
    // current_keys_index iterates through the 'left' elements of the element
    // pairs to be compared. We first need to compute the comparison block to
    // which the element belongs. The block id of that block is index /
    // block_size.
    auto block_id =
        b->CreateUDiv(current_keys_index, index_typed_constant(block_size));
    // The index of the 'left' element within its block is simply the remainder
    // when dividing by 'block_size'.
    auto index_within_block =
        b->CreateURem(current_keys_index, index_typed_constant(block_size));
    // The first element of the 'left' block of elements that is compared
    // against elements from the adjacent 'right' block of elements is
    // 'block_id' * (2 * 'block_size').
    auto first_element_in_block =
        b->CreateMul(block_id, index_typed_constant(2 * block_size));
    current_keys_index =
        b->CreateAdd(first_element_in_block, index_within_block);
  }
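  // Worked example with numbers chosen purely for illustration: for
  // xor_mask = 4 (so block_size = 4) and iteration_bound = 16, an
  // element_pair_index of 5 yields block_id = 5 / 4 = 1, index_within_block =
  // 5 % 4 = 1 and first_element_in_block = 1 * 8 = 8, so current_keys_index
  // becomes 9 and compare_keys_index (computed below) is 9 ^ 4 = 13.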
  auto compare_keys_index =
      b->CreateXor(current_keys_index, index_typed_constant(xor_mask));
  // current_keys_index < compare_keys_index
  llvm::Value* is_smaller_index =
      b->CreateICmpSLT(current_keys_index, compare_keys_index);
  // compare_keys_index < iteration_bound
  llvm::Value* index_is_inbounds = b->CreateICmpSLT(
      compare_keys_index, index_typed_constant(iteration_bound));
  llvm::Value* do_comparison =
      needs_bounds_checks ? b->CreateAnd(is_smaller_index, index_is_inbounds)
                          : b->getInt1(true);

  // if (is_smaller_index && index_is_inbounds)
  KernelSupportLibrary ksl(b);
  return ksl.IfWithStatus("smaller_comparison_index", do_comparison, [&]() {
    std::vector<llvm::Value*> values_to_compare;
    for (int i = 0; i < num_values; ++i) {
      values_to_compare.push_back(element_address(i, compare_keys_index));
      values_to_compare.push_back(element_address(i, current_keys_index));
    }
    llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
    llvm::Value* compare_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
        llvm_ir::PrimitiveTypeToIrType(PRED, module), "compare_return_buffer",
        b);
    TF_RETURN_IF_ERROR(
        emit_compare_callback(values_to_compare, compare_return_buffer));
    llvm::Value* result = b->CreateLoad(compare_return_buffer);

    // Check if the 'compare' function returns true.
    llvm::Value* is_smaller_than =
        b->CreateICmpNE(result, llvm::ConstantInt::get(result->getType(), 0),
                        "boolean_predicate");
    ksl.If("is_smaller_than", is_smaller_than, [&]() {
      for (int64 i = 0; i < num_values; ++i) {
        // Swap the values.
        auto value1 = b->CreateLoad(values_to_compare[i * 2]);
        auto value2 = b->CreateLoad(values_to_compare[i * 2 + 1]);
        write_element(i, current_keys_index, value1);
        write_element(i, compare_keys_index, value2);
      }
    });
    return Status::OK();
  });
}

Status EmitTiledCompareLoop(
    const IrArray::Index& tiled_keys_index, int64 dimension_to_sort,
    int64 dimension_to_sort_bound, absl::Span<const int64> xor_masks,
    const std::vector<IrArray>& params,
    const std::vector<llvm::Value*>& param_shmem_buffers, int64 tile_size,
    const EmitCallToNestedComputationCallback& emit_compare_callback,
    llvm::IRBuilder<>* b) {
  KernelSupportLibrary ksl(b);
  llvm::Value* thread_id = gpu::EmitCallToTargetIntrinsic(
      gpu::TargetIntrinsicID::kThreadIdx, {}, {}, b);
  llvm_ir::AddRangeMetadata(0, tile_size / 2,
                            llvm::cast<llvm::Instruction>(thread_id));
  thread_id = b->CreateIntCast(thread_id, tiled_keys_index.GetType(),
                               /*isSigned=*/true, "thread.id.x");
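  // For illustration: the [0, tile_size / 2) range metadata above reflects
  // that each block runs with tile_size / 2 threads along x, so with a
  // tile_size of 64 every thread.id.x lies in [0, 32) and handles one
  // element pair per comparison stage.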

  auto copy_loop_body =
      [&](std::function<void(llvm::Value * cache_index, llvm::Value * index)>
              read_or_write) {
        auto value_one = tiled_keys_index.GetConstantWithIndexType(1);
        auto current_keys_index =
            b->CreateShl(tiled_keys_index[dimension_to_sort], value_one);
        // We want to copy two adjacent elements. We first check whether the
        // first index position is within bounds.
        ksl.If(
            "smaller_keys_index",
            b->CreateICmpSLT(current_keys_index,
                             tiled_keys_index.GetConstantWithIndexType(
                                 dimension_to_sort_bound)),
            [&]() {
              auto cache_index = b->CreateShl(thread_id, value_one);
              read_or_write(cache_index, current_keys_index);
              // Increment to go to the next index position.
              current_keys_index = b->CreateAdd(current_keys_index, value_one);
              // Here we check whether the next index position is within
              // bounds.
              ksl.If("inner_smaller_keys_index",
                     b->CreateICmpSLT(current_keys_index,
                                      tiled_keys_index.GetConstantWithIndexType(
                                          dimension_to_sort_bound)),
                     [&]() {
                       cache_index = b->CreateAdd(cache_index, value_one);
                       read_or_write(cache_index, current_keys_index);
                     });
            });
      };
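  // For illustration: a thread with thread.id.x = 5 whose tiled_keys_index
  // along the sort dimension is k copies operand elements 2 * k and 2 * k + 1
  // into shared memory slots 10 and 11, each copy guarded by the bounds
  // checks above.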

  // Copy operand tiles from the operand buffers to shared memory.
  std::vector<llvm::Value*> keys_multi_index = tiled_keys_index.multidim();
  for (int64 i = 0; i < params.size(); ++i) {
    copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) {
      keys_multi_index[dimension_to_sort] = index;
      IrArray::Index keys_index(keys_multi_index, params[i].GetShape(),
                                tiled_keys_index.GetType());
      auto value = params[i].EmitReadArrayElement(keys_index, b);
      b->CreateStore(value,
                     b->CreateGEP(param_shmem_buffers[i],
                                  {tiled_keys_index.GetConstantWithIndexType(0),
                                   cache_index}));
    });
  }
  // Wait until all reads have happened.
  gpu::EmitCallToTargetIntrinsic(gpu::TargetIntrinsicID::kBarrierId, {}, {}, b);

  // Now emit the bodies of the comparison loops.
  auto element_address = [&](int64 operand, llvm::Value* index) {
    auto shared_memory_address =
        b->CreateGEP(param_shmem_buffers[operand],
                     {tiled_keys_index.GetConstantWithIndexType(0), index});
    auto ptr_type = shared_memory_address->getType();
    // We need a generic pointer with address space 0 instead of a pointer to
    // shared memory (address space 3) so that we can pass it to the comparison
    // computation.
    return b->CreateAddrSpaceCast(
        shared_memory_address,
        llvm::PointerType::get(ptr_type->getPointerElementType(),
                               /*AddressSpace=*/0));
  };
  auto write_element = [&](int64 operand, llvm::Value* index,
                           llvm::Value* value) {
    b->CreateStore(
        value,
        b->CreateGEP(param_shmem_buffers[operand],
                     {tiled_keys_index.GetConstantWithIndexType(0), index}));
  };
  for (int64 xor_mask : xor_masks) {
    // The index of the element pair to be compared within the tile stored in
    // shared memory. We order the element pairs by the element with the
    // smaller index.
    auto element_pair_index = thread_id;
    // If 'dimension_to_sort_bound' is evenly divisible by 'tile_size', we
    // don't need any bounds checks.
    if (dimension_to_sort_bound % tile_size) {
      // Otherwise we need a bounds check for the last tile. The last tile has
      // size 'dimension_to_sort_bound' % 'tile_size'.
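      // For example, with dimension_to_sort_bound = 200 and tile_size = 64 the
      // last tile holds 200 % 64 = 8 elements; iterations whose first element
      // index (2 * tiled_keys_index[dimension_to_sort]) is at least
      // RoundDownToNearest(200, 64) = 192 take the bounds-checked path below,
      // all others take the unchecked one.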
      TF_RETURN_IF_ERROR(ksl.IfWithStatus(
          "is_last_tile",
          b->CreateICmpUGE(
              b->CreateMul(tiled_keys_index[dimension_to_sort],
                           tiled_keys_index.GetConstantWithIndexType(2)),
              tiled_keys_index.GetConstantWithIndexType(
                  RoundDownToNearest(dimension_to_sort_bound, tile_size))),
          [&]() {
            return EmitCompareLoopBody(
                dimension_to_sort_bound % tile_size, params.size(),
                element_pair_index, xor_mask, tiled_keys_index.GetType(),
                element_address, write_element, emit_compare_callback, b);
          },
          [&]() {
            return EmitCompareLoopBody(
                tile_size, params.size(), element_pair_index, xor_mask,
                tiled_keys_index.GetType(), element_address, write_element,
                emit_compare_callback, b,
                /*needs_bounds_checks=*/false);
          }));
    } else {
      TF_RETURN_IF_ERROR(EmitCompareLoopBody(
          tile_size, params.size(), element_pair_index, xor_mask,
          tiled_keys_index.GetType(), element_address, write_element,
          emit_compare_callback, b,
          /*needs_bounds_checks=*/false));
    }
    // Wait until all comparisons have happened.
    gpu::EmitCallToTargetIntrinsic(gpu::TargetIntrinsicID::kBarrierId, {}, {},
                                   b);
  }

  // Copy the operand tiles back from shared memory to the operand buffers.
  for (int64 i = 0; i < params.size(); ++i) {
    copy_loop_body([&](llvm::Value* cache_index, llvm::Value* index) {
      keys_multi_index[dimension_to_sort] = index;
      IrArray::Index keys_index(keys_multi_index, params[i].GetShape(),
                                tiled_keys_index.GetType());
      auto value = b->CreateLoad(b->CreateGEP(
          param_shmem_buffers[i],
          {tiled_keys_index.GetConstantWithIndexType(0), cache_index}));
      params[i].EmitWriteArrayElement(keys_index, value, b);
    });
  }
  // We should normally synchronize here to make sure all writes have happened.
  // However the very next thing each thread does is reading 2 elements from
  // the operand buffer and writing them into the same locations in shared
  // memory from which it previously copied them to the operand buffer, and we
  // synchronize after this has happened. We can be sure that a thread always
  // writes to the same locations in shared memory because we have exactly
  // tile_size / 2 many threads, and the linear index calculated by
  // ParallelLoopEmitter uses linear_index = blockIdx.x * blockDim.x +
  // threadIdx.x;
  return Status::OK();
}
}  // namespace

Status EmitSortInPlace(
    int64 dimension_to_sort, const std::vector<IrArray>& values_arrays,
    absl::string_view name, absl::Span<const int64> xor_masks,
    llvm::IRBuilder<>* b, const gpu::LaunchDimensions& launch_dimensions,
    int64 num_iterations_in_sort_dim, const int64 tile_size,
    const EmitCallToNestedComputationCallback& emit_compare_callback) {
  // Iterate through the keys shape in physical order, but skip the dimension
  // to sort and make it the innermost loop which is the loop where the
  // comparisons happen. In the dimension to sort, if we use tiling, we iterate
  // through it in tiles of 64 elements each, so we use another loop that
  // happens within one thread to process this tile worth of data (thereby
  // combining several comparison stages of the bitonic sort algorithm because
  // they all happen within those 64 elements and are therefore independent of
  // the other comparisons).

  const Shape& keys_shape = values_arrays[0].GetShape();
  int64 rank = keys_shape.rank();
  int64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
  std::vector<int64> dimensions_in_iteration_order(rank);
  std::vector<int64> iteration_order_to_logical_order(rank);
  int64 dim = 0;
  for (int64 dimension : LayoutUtil::MinorToMajor(keys_shape)) {
    if (dimension != dimension_to_sort) {
      dimensions_in_iteration_order[dim] = keys_shape.dimensions(dimension);
      iteration_order_to_logical_order[dim++] = dimension;
    }
  }
  dimensions_in_iteration_order[dim] = num_iterations_in_sort_dim;
  iteration_order_to_logical_order[dim] = dimension_to_sort;

  Shape iteration_shape = ShapeUtil::MakeShape(keys_shape.element_type(),
                                               dimensions_in_iteration_order);
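  // For illustration (hypothetical shape): sorting dimension 0 of an
  // f32[200, 10] array with the default {1, 0} layout gives
  // dimensions_in_iteration_order = {10, num_iterations_in_sort_dim} and
  // iteration_order_to_logical_order = {1, 0}, so iteration_shape is
  // f32[10, num_iterations_in_sort_dim].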

  // Allocate shared memory for the tiled compare loop.
  std::vector<llvm::Value*> param_shmem_buffers(values_arrays.size(), nullptr);
  if (xor_masks.size() > 1) {
    llvm::Module* module = b->GetInsertBlock()->getParent()->getParent();
    for (int64 i = 0; i < values_arrays.size(); ++i) {
      llvm::Type* tile_type = llvm::ArrayType::get(
          llvm_ir::PrimitiveTypeToIrType(
              values_arrays[i].GetShape().element_type(), module),
          tile_size);
      param_shmem_buffers[i] = llvm_ir::AllocateSharedMemoryTile(
          module, tile_type, absl::StrCat(name, "_tile_param_", i));
    }
  }

  auto compare_loop_body_emitter =
      [&](const IrArray::Index& tiles_index) -> Status {
    // Naive C++ code for the inner compare loop:
    //
    // for (int64 i = 0; i < dimension_to_sort_bound; ++i) {
    //   int64 j = i ^ xor_mask;
    //   /* emitted in EmitCompareLoopBody() */
    //   if (i < j && j < dimension_to_sort_bound) {
    //     int64 min_key = std::min(keys[i], keys[j]);
    //     keys[j] = std::max(keys[i], keys[j]);
    //     keys[i] = min_key;
    //   }
    // }
    //
    // This follows the algorithm described on Wikipedia:
    // https://en.wikipedia.org/wiki/Bitonic_sorter
    std::vector<llvm::Value*> keys_multi_index(rank);
    for (int64 i = 0; i < rank; ++i) {
      keys_multi_index[iteration_order_to_logical_order[i]] = tiles_index[i];
    }
    if (xor_masks.size() > 1) {
      IrArray::Index keys_index(keys_multi_index, values_arrays[0].GetShape(),
                                tiles_index.GetType());
      TF_RETURN_IF_ERROR(EmitTiledCompareLoop(
          keys_index, dimension_to_sort, dimension_to_sort_bound, xor_masks,
          values_arrays, param_shmem_buffers, tile_size, emit_compare_callback,
          b));
    } else {
      auto element_address = [&](int64 operand, llvm::Value* index) {
        keys_multi_index[dimension_to_sort] = index;
        IrArray::Index keys_index(keys_multi_index,
                                  values_arrays[operand].GetShape(),
                                  tiles_index.GetType());
        return values_arrays[operand].EmitArrayElementAddress(keys_index, b);
      };
      auto write_element = [&](int64 operand, llvm::Value* index,
                               llvm::Value* value) {
        keys_multi_index[dimension_to_sort] = index;
        IrArray::Index keys_index(keys_multi_index,
                                  values_arrays[operand].GetShape(),
                                  tiles_index.GetType());
        values_arrays[operand].EmitWriteArrayElement(keys_index, value, b);
      };
      TF_RETURN_IF_ERROR(EmitCompareLoopBody(
          dimension_to_sort_bound, values_arrays.size(), tiles_index[rank - 1],
          xor_masks[0], tiles_index.GetType(), element_address, write_element,
          emit_compare_callback, b));
    }
    return Status::OK();
  };
  return gpu::ParallelLoopEmitter(compare_loop_body_emitter, iteration_shape,
                                  launch_dimensions, b)
      .EmitLoop(name);
}

}  // namespace llvm_ir
}  // namespace xla