1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
17
18 #include <algorithm>
19 #include <cstring>
20 #include <iterator>
21 #include <memory>
22 #include <string>
23 #include <vector>
24
25 #include "absl/algorithm/container.h"
26 #include "absl/container/inlined_vector.h"
27 #include "absl/memory/memory.h"
28 #include "absl/strings/str_cat.h"
29 #include "absl/types/optional.h"
30 #include "absl/types/span.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/IR/BasicBlock.h"
33 #include "llvm/IR/Function.h"
34 #include "llvm/IR/IRBuilder.h"
35 #include "llvm/IR/Instructions.h"
36 #include "llvm/IR/LLVMContext.h"
37 #include "llvm/IR/Module.h"
38 #include "tensorflow/compiler/xla/layout_util.h"
39 #include "tensorflow/compiler/xla/literal.h"
40 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
41 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
42 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
43 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
44 #include "tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h"
45 #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
46 #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
47 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
48 #include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
49 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
50 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
51 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
52 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
53 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
54 #include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
55 #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
56 #include "tensorflow/compiler/xla/service/gpu/memset_thunk.h"
57 #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
58 #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
59 #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
60 #include "tensorflow/compiler/xla/service/gpu/replica_id_thunk.h"
61 #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
62 #include "tensorflow/compiler/xla/service/gpu/target_util.h"
63 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
64 #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h"
65 #include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
66 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
67 #include "tensorflow/compiler/xla/service/hlo_computation.h"
68 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
69 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
70 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
71 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
72 #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
73 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
74 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
75 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
76 #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"
77 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
78 #include "tensorflow/compiler/xla/service/name_uniquer.h"
79 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
80 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
81 #include "tensorflow/compiler/xla/shape_util.h"
82 #include "tensorflow/compiler/xla/status_macros.h"
83 #include "tensorflow/compiler/xla/types.h"
84 #include "tensorflow/compiler/xla/util.h"
85 #include "tensorflow/compiler/xla/window_util.h"
86 #include "tensorflow/compiler/xla/xla_data.pb.h"
87 #include "tensorflow/core/lib/core/bits.h"
88 #include "tensorflow/core/lib/core/status.h"
89 #include "tensorflow/core/platform/logging.h"
90
91 namespace xla {
92 namespace gpu {
93
94 namespace {
95
96 using absl::InlinedVector;
97 using absl::nullopt;
98 using absl::optional;
99 using absl::StrCat;
100 using llvm_ir::IrArray;
101 using llvm_ir::IrName;
102
103 const auto kDimX = KernelMappingScheme::DimX;
104 const auto kDimY = KernelMappingScheme::DimY;
105 const auto kDimZ = KernelMappingScheme::DimZ;
106 const auto kDimTot = KernelMappingScheme::DimTot;
107
108 // If a dimensions is smaller than this, untiled transposition may be more
109 // efficient.
110 const int64 kMinDimensionToTransposeTiled = 16;
111
112 // Updates the launch dimensions in "thunk" and annotate the launch dimensions
113 // of the corresponding IR kernel in "llvm_module".
114 // Precondition: "thunk" must be a KernelThunk.
UpdateLaunchDimensions(const LaunchDimensions & launch_dims,Thunk * thunk,llvm::Module * llvm_module)115 void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk,
116 llvm::Module* llvm_module) {
117 CHECK(Thunk::Kind::kKernel == thunk->kind());
118 KernelThunk* kernel_thunk = static_cast<KernelThunk*>(thunk);
119 kernel_thunk->SetLaunchDimensions(launch_dims);
120
121 // Add __launch_bounds__ to metadata. This limits registers per thread to
122 // avoid out-of-resources launching errors.
123 llvm::NamedMDNode* nvvm_annotations_node =
124 llvm_module->getOrInsertNamedMetadata("nvvm.annotations");
125 llvm::Function* ir_kernel =
126 llvm_module->getFunction(kernel_thunk->kernel_name().c_str());
127 llvm::LLVMContext& llvm_context = llvm_module->getContext();
128 llvm::ConstantInt* threads_per_block_ir_value = llvm::ConstantInt::get(
129 llvm::IntegerType::get(llvm_context, /*NumBits=*/32),
130 launch_dims.threads_per_block());
131 // Our launch bounds are exact, so we can specify them as reqntidx rather than
132 // maxntidx.
133 nvvm_annotations_node->addOperand(llvm::MDNode::get(
134 llvm_context,
135 {llvm::ConstantAsMetadata::get(ir_kernel),
136 llvm::MDString::get(llvm_context, "reqntidx"),
137 llvm::ConstantAsMetadata::get(threads_per_block_ir_value)}));
138 }
139
140 } // namespace
141
IrEmitterUnnested(const HloModuleConfig & hlo_module_config,const HloComputation * hlo_computation,IrEmitterContext * ir_emitter_context)142 IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
143 const HloComputation* hlo_computation,
144 IrEmitterContext* ir_emitter_context)
145 : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false),
146 hlo_computation_(hlo_computation) {
147 // Initialize thunk_sequence_ to an empty list of thunks.
148 thunk_sequence_.reset(new ThunkSequence());
149 }
150
Postprocess(HloInstruction * hlo)151 Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) {
152 bindings_.UnbindAllLocalIrValues();
153 return DfsHloVisitor::Postprocess(hlo);
154 }
155
BuildKernelPrototype(const HloInstruction & inst,absl::Span<const BufferAllocation * const> args)156 llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
157 const HloInstruction& inst,
158 absl::Span<const BufferAllocation* const> args) {
159 // Compute the kernel name. The opcode string may contain "-" which cannot be
160 // in a PTX function name, so sanitize the name before uniquifying it.
161 string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName(
162 llvm_ir::SanitizeFunctionName(inst.name()));
163
164 // Create the kernel and add it to the module.
165 llvm::Module* module = ir_emitter_context_->llvm_module();
166 llvm::LLVMContext& context = module->getContext();
167 llvm::FunctionType* kernel_type = llvm::FunctionType::get(
168 /*Result=*/llvm::Type::getVoidTy(context),
169 std::vector<llvm::Type*>(args.size(), b_.getInt8PtrTy()),
170 /*isVarArg=*/false);
171 llvm::Function* kernel =
172 llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage,
173 kernel_name.c_str(), module);
174
175 // Add dereferenceable and alignment information to each of the kernel's
176 // parameters.
177 auto arg_it = kernel->arg_begin();
178 for (size_t arg_no = 0; arg_no < args.size(); ++arg_no) {
179 const BufferAllocation* alloc = args[arg_no];
180 llvm::Argument* fn_arg = &*arg_it;
181 ++arg_it;
182
183 kernel->addDereferenceableAttr(arg_no + 1, alloc->size());
184
185 const int64 alignment = [&] {
186 if (alloc->is_entry_computation_parameter()) {
187 return kEntryParameterAlignBytes;
188 } else if (alloc->is_constant()) {
189 return kConstantBufferAlignBytes;
190 } else {
191 return kXlaAllocatedBufferAlignBytes;
192 }
193 }();
194
195 kernel->addParamAttr(
196 arg_no,
197 llvm::Attribute::get(context, llvm::Attribute::Alignment, alignment));
198
199 if (alloc->IsPreallocatedTempBuffer()) {
200 fn_arg->setName("temp_buf");
201 } else {
202 fn_arg->setName(StrCat("alloc", alloc->index()));
203 }
204 }
205
206 AnnotateFunctionAsGpuKernel(module, kernel, &b_);
207
208 // TODO(b/65380986): Investigate if adding fast math flags for generated
209 // kernels makes sense.
210
211 // Update the insert point to the entry basic block.
212 llvm::BasicBlock* entry_bb =
213 llvm::BasicBlock::Create(context, /*Name=*/"entry", /*Parent=*/kernel);
214
215 // Emit a "return void" at entry_bb's end, and set the insert point before
216 // that return instruction.
217 b_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb));
218
219 return kernel;
220 }
221
222 namespace {
223 // Computes the maximum valid unroll factor for a given instruction.
ComputeMaxUnrollFactor(const HloInstruction * hlo)224 int ComputeMaxUnrollFactor(const HloInstruction* hlo) {
225 int max_unroll_factor = hlo->GetModule()
226 ->config()
227 .debug_options()
228 .xla_gpu_max_kernel_unroll_factor();
229
230 // Find the largest possible power of two to unroll by.
231 // TODO(kramerb): Make this smarter.
232 const Shape& element_shape = hlo->IsMultiOutputFusion()
233 ? ShapeUtil::GetSubshape(hlo->shape(), {0})
234 : hlo->shape();
235 int64 num_elements = ShapeUtil::ElementsIn(element_shape);
236 for (int i = max_unroll_factor; i > 1; i /= 2) {
237 if (num_elements % i == 0) {
238 return i;
239 }
240 }
241
242 // Cannot unroll.
243 return 1;
244 }
245
246 // Returns the llvm type for the indices used in the kernel that contains the
247 // hlo instruction. Such indices include the index for the parallel loop and
248 // the indices for the tensors accessed by the kernel. The return type is i32
249 // iff the following conditions are met:
250 // . The launch_size of the kernel is within the range of i32.
251 // . The sizes of all the tensors accessed within the kernel are within the
252 // range of i32.
253 // Otherwise, the return type is i64.
GetIndexTypeForKernel(const HloInstruction * hlo,int64 launch_size,llvm::IRBuilder<> * b)254 llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size,
255 llvm::IRBuilder<>* b) {
256 // Find the unnested hlo instruction for which the kernel is generated for.
257 const HloInstruction* unnested_hlo = hlo;
258 const HloComputation* computation = hlo->parent();
259 if (computation->IsFusionComputation()) {
260 unnested_hlo = computation->FusionInstruction();
261 }
262
263 auto shape_in_range = [&](const Shape& s) {
264 bool in_range = true;
265 ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape,
266 const ShapeIndex& /*index*/) {
267 if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) {
268 in_range = false;
269 }
270 });
271
272 return in_range;
273 };
274
275 llvm::Type* i64_ty = b->getInt64Ty();
276 // Check launch dimension
277 if (!IsInt32(launch_size)) {
278 return i64_ty;
279 }
280
281 // Check the size of result tensors
282 if (!shape_in_range(unnested_hlo->shape())) {
283 return i64_ty;
284 }
285
286 auto hlo_shape_in_range = [&](const HloInstruction* operand) -> bool {
287 return shape_in_range(operand->shape());
288 };
289
290 // Check the size of input tensors
291 if (!absl::c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) {
292 return i64_ty;
293 }
294
295 // Check the size of the internal result tensors
296 if (unnested_hlo->opcode() == HloOpcode::kFusion) {
297 if (!absl::c_all_of(
298 unnested_hlo->fused_instructions_computation()->instructions(),
299 hlo_shape_in_range)) {
300 return i64_ty;
301 }
302 }
303
304 return b->getInt32Ty();
305 }
306
307 // Gets the input shape of the ROOT slices, which will be used as the kernel
308 // launch dims. The slice input fusion requires the input shapes of the ROOT
309 // slices to be the same although the (slice) output shapes can be different.
310 //
311 // Returns the input shape of the ROOT slices if all the input shapes of ROOT
312 // slices are the same and the slices are non-strided. Otherwise, returns
313 // FailedPrecondition.
GetConsistentInputShapeForRootSlices(const HloInstruction & fusion)314 StatusOr<Shape> GetConsistentInputShapeForRootSlices(
315 const HloInstruction& fusion) {
316 if (!IsInputFusibleSlices(fusion, /*verify_no_strides=*/true)) {
317 return FailedPrecondition(
318 "Unsupported root for slice input fusion. "
319 "Only non-strided slices are supported.");
320 }
321
322 const HloInstruction& root = *fusion.fused_expression_root();
323 if (root.opcode() == HloOpcode::kSlice) {
324 return root.operands()[0]->shape();
325 }
326
327 CHECK_EQ(root.opcode(), HloOpcode::kTuple);
328 const Shape& first_slice_operand_shape =
329 root.operands()[0]->operands()[0]->shape();
330 for (size_t i = 1; i < root.operands().size(); ++i) {
331 const HloInstruction* slice = root.operands()[i];
332 const Shape& operand_shape = slice->operands()[0]->shape();
333 if (!ShapeUtil::EqualIgnoringElementType(first_slice_operand_shape,
334 operand_shape)) {
335 return FailedPrecondition(
336 "Fused slices do not have the same input shape, fused computation = "
337 "%s.",
338 root.parent()->name());
339 }
340 }
341
342 return first_slice_operand_shape;
343 }
344
345 } // namespace
346
DefaultAction(HloInstruction * hlo)347 Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
348 return IrEmitter::DefaultAction(hlo);
349 }
350
HandleDot(HloInstruction * dot)351 Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
352 AddThunkToThunkSequence(
353 BuildKernelThunk(dot, /*implements_whole_instruction=*/true));
354 return IrEmitter::HandleDot(dot);
355 }
356
HandleConditional(HloInstruction * conditional)357 Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
358 AddThunkToThunkSequence(BuildConditionalThunk(conditional));
359 return Status::OK();
360 }
361
HandleConvolution(HloInstruction * convolution)362 Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) {
363 AddThunkToThunkSequence(
364 BuildKernelThunk(convolution, /*implements_whole_instruction=*/true));
365 return IrEmitter::HandleConvolution(convolution);
366 }
367
HandleCustomCall(HloInstruction * custom_call)368 Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
369 return ThunkEmitter(this).HandleCustomCall(custom_call);
370 }
371
HandleFft(HloInstruction * fft)372 Status IrEmitterUnnested::HandleFft(HloInstruction* fft) {
373 return ThunkEmitter(this).HandleFft(fft);
374 }
375
HandleTriangularSolve(HloInstruction * hlo)376 Status IrEmitterUnnested::HandleTriangularSolve(HloInstruction* hlo) {
377 return ThunkEmitter(this).HandleTriangularSolve(hlo);
378 }
379
HandleFusion(HloInstruction * fusion)380 Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
381 HloInstruction* root = fusion->fused_expression_root();
382 if (fusion->IsInputFusion()) {
383 switch (root->opcode()) {
384 case HloOpcode::kScatter: {
385 std::vector<std::unique_ptr<Thunk>> thunks;
386 // The initialization from 'operand' is using different loop bounds, so
387 // emit it in a separate kernel. Treat it like a loop fusion, writing to
388 // the output buffer.
389 {
390 int unroll_factor = ComputeMaxUnrollFactor(fusion);
391 thunks.push_back(BuildKernelThunk(
392 fusion, /*implements_whole_instruction=*/false, unroll_factor));
393 GpuElementalIrEmitter operand_elemental_emitter(
394 hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
395 GetNestedComputer());
396 FusedIrEmitter operand_fused_emitter(
397 GetGeneratorForOperandIrArrays(fusion),
398 &operand_elemental_emitter);
399 TF_RETURN_IF_ERROR(
400 root->mutable_operand(0)->Accept(&operand_fused_emitter));
401
402 TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk(
403 *fusion, operand_fused_emitter.GetGenerator(root->operand(0)),
404 static_cast<KernelThunk*>(thunks.back().get())));
405 }
406
407 // Now build the actual scatter, reading and writing to the freshly
408 // filled output buffer.
409 {
410 thunks.push_back(
411 BuildKernelThunk(fusion,
412 /*implements_whole_instruction=*/false));
413 // Spin up a new fused emitter for the scatter kernel and emit it.
414 GpuElementalIrEmitter scatter_elemental_emitter(
415 hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
416 GetNestedComputer());
417 FusedIrEmitter scatter_fused_emitter(
418 GetGeneratorForOperandIrArrays(fusion),
419 &scatter_elemental_emitter);
420 TF_RETURN_IF_ERROR(root->Accept(&scatter_fused_emitter));
421 TF_RETURN_IF_ERROR(EmitScatter(
422 thunks.back().get(), root,
423 /*scatter_indices_gen=*/
424 scatter_fused_emitter.GetGenerator(root->operand(1)),
425 /*updates_gen=*/
426 scatter_fused_emitter.GetGenerator(root->operand(2))));
427 }
428 AddThunkToThunkSequence(
429 absl::make_unique<SequentialThunk>(std::move(thunks), fusion));
430 return Status::OK();
431 }
432 // In the case of root tuple, it can be either reduce or slice input
433 // fusion.
434 case HloOpcode::kTuple: {
435 if (IsInputFusibleSlices(*fusion)) {
436 return EmitInputFusibleNonStridedSlices(fusion);
437 }
438
439 CHECK_GE(root->operand_count(), 1);
440 return EmitReductionFromOrToContiguousDimensions(fusion,
441 root->operands());
442 }
443 case HloOpcode::kReduce: {
444 // HandleFusion specializes reduction from a multi-dimensional array to
445 // a 1D array. The specialized version requires a initializer thunk that
446 // initializes the output array to the initial value of the reduce.
447 if (root->shape().IsTuple()) {
448 // TODO(b/129089333): Support tiled vectorized variadic reduce.
449 return Unimplemented(
450 "Vectorized variadic reduce is not supported on GPU");
451 }
452 return EmitReductionFromOrToContiguousDimensions(fusion, {root});
453 }
454 case HloOpcode::kSlice: {
455 return EmitInputFusibleNonStridedSlices(fusion);
456 }
457 default:
458 LOG(FATAL) << "Bad opcode for input fusion: "
459 << fusion->fused_expression_root()->opcode();
460 }
461 } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(
462 fusion, ir_emitter_context_->buffer_assignment())) {
463 // Fusion node with dynamic-update-slice as the root where the op's input
464 // (i.e. array to update) shares the same slice as its output. In this case
465 // we have a special algorithm that modifies the output in place without
466 // touching the un-updated elements.
467
468 // Set up kernel thunk and fused ir emitter.
469 std::unique_ptr<KernelThunk> fusion_thunk =
470 BuildKernelThunk(fusion, /*implements_whole_instruction=*/true);
471 GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
472 ir_emitter_context_->llvm_module(),
473 &b_, GetNestedComputer());
474
475 // Shape of the dynamic-update-slice's "update" operand.
476 Shape update_shape = root->operand(1)->shape();
477
478 // Array to write into. Because this is an in-place operation, this is the
479 // same as operand 0's array.
480 IrArray output_array = GetIrArray(*fusion, *fusion);
481
482 LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
483 update_shape, ir_emitter_context_->device_description());
484 UpdateLaunchDimensions(launch_dimensions, fusion_thunk.get(),
485 ir_emitter_context_->llvm_module());
486 AddThunkToThunkSequence(std::move(fusion_thunk));
487
488 return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace(
489 fusion, GetGeneratorForOperandIrArrays(fusion), output_array,
490 &elemental_emitter, launch_dimensions, &b_);
491 }
492
493 CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop)
494 << ": " << fusion->ToString();
495
496 if (CheckAndEmitHloWithTile021(fusion)) {
497 return Status::OK();
498 }
499
500 return IrEmitter::HandleFusion(fusion);
501 }
502
HandleCopy(HloInstruction * copy)503 Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
504 CHECK(ShapeUtil::Compatible(copy->operand(0)->shape(), copy->shape()));
505 const BufferAssignment& buffer_assignment =
506 ir_emitter_context_->buffer_assignment();
507 if (LayoutUtil::Equal(copy->operand(0)->shape().layout(),
508 copy->shape().layout()) &&
509 buffer_assignment.GetUniqueTopLevelSlice(copy->operand(0)).ok()) {
510 // Copy the operand into the output if it's not the same buffer already.
511 auto operand_buffer = GetAllocationSlice(*copy->operand(0));
512 auto destination_buffer = GetAllocationSlice(*copy);
513 if (operand_buffer != destination_buffer) {
514 AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
515 /*source_address=*/operand_buffer,
516 /*destination_buffer=*/destination_buffer,
517 /*mem_size=*/
518 ByteSizeOf(copy->operand(0)->shape()), copy));
519 }
520 return Status::OK();
521 }
522 if (CheckAndEmitHloWithTile021(copy)) {
523 return Status::OK();
524 }
525
526 return IrEmitter::HandleCopy(copy);
527 }
528
EmitExtraOutputsForReduce(const HloInstruction * unnested_hlo,const IrArray::Index & index,bool use_linear_index,absl::Span<const std::pair<llvm_ir::ElementGenerator,ShapeIndex>> extra_output_gens)529 Status IrEmitterUnnested::EmitExtraOutputsForReduce(
530 const HloInstruction* unnested_hlo, const IrArray::Index& index,
531 bool use_linear_index,
532 absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
533 extra_output_gens) {
534 for (int i = 0; i != extra_output_gens.size(); ++i) {
535 llvm::Value* extra_output_address =
536 GetIrArray(*unnested_hlo, *unnested_hlo, extra_output_gens[i].second)
537 .EmitArrayElementAddress(index, &b_, "extra_output_element_address",
538 use_linear_index);
539 TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
540 extra_output_gens[i].first(index));
541 Store(extra_output_ir_value, extra_output_address);
542 }
543 return Status::OK();
544 }
545
HandleReduce(HloInstruction * reduce)546 Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
547 if (IsReductionFromOrToContiguousDimensions(*reduce) &&
548 reduce->shape().IsArray()) {
549 return EmitReductionFromOrToContiguousDimensions(reduce, {reduce});
550 }
551
552 return IrEmitter::HandleReduce(reduce);
553 }
554
HandleTuple(HloInstruction * tuple)555 Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
556 // For the root node of the entry computation we can elide writing the tuple
557 // buffer. We can always figure out the contents of the tuples from buffer
558 // assignment because we insert copies to ensure non-ambiguous output buffers.
559 // GpuExecutable never reads the tuple buffer.
560 if (tuple ==
561 tuple->parent()->parent()->entry_computation()->root_instruction()) {
562 return Status::OK();
563 }
564 bool all_tuple_elements_have_buffer =
565 absl::c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) {
566 return ir_emitter_context_->buffer_assignment()
567 .GetUniqueTopLevelSlice(tuple_element)
568 .ok();
569 });
570 // TODO(b/111689850): This logic isn't quite correct.
571 //
572 // Tuples (especially tuples that are the final result of a computation) can
573 // be so huge that if we were to emit a kernel that took each tuple element as
574 // a parameter, we would exceed the max allowable number of parameters to a
575 // GPU kernel, b/31336476. As an optimization, if all tuple elements have a
576 // buffer, we collect their buffer addresses in a host array, and then copy
577 // that array to the tuple's buffer.
578 //
579 // Some tuple elements might not have an unambiguous buffer (like the result
580 // of a select-tuple). In that case, we fall back to emitting kernels which
581 // have access to their buffer addresses in code.
582 if (all_tuple_elements_have_buffer) {
583 std::vector<BufferAllocation::Slice> tuple_element_buffers;
584 for (const HloInstruction* tuple_element : tuple->operands()) {
585 tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element));
586 }
587 AddThunkToThunkSequence(absl::make_unique<TupleThunk>(
588 tuple_element_buffers, GetAllocationSlice(*tuple), tuple));
589 return Status::OK();
590 }
591 AddThunkToThunkSequence(
592 BuildKernelThunk(tuple, /*implements_whole_instruction=*/true));
593 return IrEmitter::HandleTuple(tuple);
594 }
595
HandleGetTupleElement(HloInstruction *)596 Status IrEmitterUnnested::HandleGetTupleElement(HloInstruction*) {
597 // GetTupleElement IR is emitted in the IR context of the user instruction,
598 // and so we do not build a kernel for GetTupleElement instructions.
599 return Status::OK();
600 }
601
HandleSelectAndScatter(HloInstruction * select_and_scatter)602 Status IrEmitterUnnested::HandleSelectAndScatter(
603 HloInstruction* select_and_scatter) {
604 CHECK_EQ(select_and_scatter->operand_count(), 3);
605 const auto* operand = select_and_scatter->operand(0);
606 const auto* source = select_and_scatter->operand(1);
607 const Window& window = select_and_scatter->window();
608 PrimitiveType operand_element_type = operand->shape().element_type();
609 const int64 rank = operand->shape().rank();
610 CHECK_EQ(rank, source->shape().rank());
611 CHECK_EQ(rank, window.dimensions_size());
612
613 TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
614 BuildInitializerThunk(select_and_scatter));
615 std::vector<std::unique_ptr<Thunk>> thunks;
616 thunks.push_back(std::move(initializer_thunk));
617 thunks.push_back(BuildKernelThunk(select_and_scatter,
618 /*implements_whole_instruction=*/false));
619 std::unique_ptr<SequentialThunk> select_and_scatter_thunk =
620 absl::make_unique<SequentialThunk>(std::move(thunks), select_and_scatter);
621
622 // TODO(b/31410564): Implement dilation rate for select-and-scatter.
623 if (window_util::HasDilation(window)) {
624 return Unimplemented(
625 "Dilation for SelectAndScatter not implemented on GPU.");
626 }
627
628 LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
629 source->shape(), ir_emitter_context_->device_description());
630 llvm::Type* index_type = GetIndexTypeForKernel(
631 select_and_scatter, launch_dimensions.launch_bound(), &b_);
632 auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
633 return llvm::ConstantInt::get(index_type, c);
634 };
635
636 // kSelectAndScatter is implemented as two kernel launches: the first launch
637 // initializes the output array to the given initial value,
638 // and the second accumulates the "source" matrix to the
639 // selected elements in the output array. The first launch is already
640 // implemented by the initializer thunk generated earlier, so this function
641 // only needs to take care of the select-and-scatter part.
642 //
643 // Pseudo code for select-and-scatter:
644 //
645 // for (coordinates S in the source): # This loop is parallel.
646 // initialized_flag = false
647 // for (coordinates W in the window):
648 // I = S * stride + W - pad_low
649 // if I within bounds of operand:
650 // if !(initialized_flag and select(selected_value, operand(I))):
651 // selected_value = operand(I)
652 // selected_index = I
653 // initialized_flag = true
654 // output(selected_index) = scatter(output(selected_index), source(S))
655 auto loop_body_emitter = [=](const IrArray::Index& source_index) -> Status {
656 // Allocate space to keep the currently selected value, its index, and a
657 // boolean flag if the value is initialized. The initialized_flag is set
658 // false.
659 llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
660 llvm_ir::PrimitiveTypeToIrType(operand_element_type,
661 ir_emitter_context_->llvm_module()),
662 "selected_value_address", &b_);
663 llvm::Value* selected_index_address =
664 llvm_ir::EmitAllocaAtFunctionEntryWithCount(
665 index_type, index_typed_constant(rank), "selected_index_address",
666 &b_);
667 llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
668 b_.getInt1Ty(), "initialized_flag_address", &b_);
669 Store(b_.getInt1(false), initialized_flag_address);
670
671 // Create the inner loop to iterate over the window.
672 llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_,
673 index_type);
674 DimensionVector window_size;
675 for (const auto& dim : window.dimensions()) {
676 window_size.push_back(dim.size());
677 CHECK_GT(dim.size(), 0);
678 }
679 const IrArray::Index window_index = window_loops.AddLoopsForShape(
680 ShapeUtil::MakeShape(operand_element_type, window_size), "window");
681 llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
682 &b_);
683
684 // Compute the operand index to visit and evaluate the condition whether the
685 // operand index is within the bounds. The unsigned comparison includes
686 // checking whether the operand index >= 0.
687 std::vector<llvm::Value*> operand_multi_index(source_index.size());
688 llvm::Value* in_bounds_condition = b_.getInt1(true);
689 for (int64 i = 0; i < rank; ++i) {
690 llvm::Value* strided_index = NSWMul(
691 source_index[i], index_typed_constant(window.dimensions(i).stride()));
692 operand_multi_index[i] =
693 NSWSub(NSWAdd(strided_index, window_index[i]),
694 index_typed_constant(window.dimensions(i).padding_low()));
695 llvm::Value* index_condition = ICmpULT(
696 operand_multi_index[i],
697 index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i)));
698 in_bounds_condition = And(in_bounds_condition, index_condition);
699 }
700 CHECK(in_bounds_condition != nullptr);
701
702 // Only need to do something if the operand index is within the bounds.
703 // First check if the initialized_flag is set.
704 llvm_ir::LlvmIfData if_in_bounds =
705 llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
706 llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
707 llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
708 Load(initialized_flag_address), "initialized", &b_);
709
710 // If the initialized_flag is false, initialize the selected value and index
711 // with the currently visiting operand.
712 llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &b_);
713 const auto save_operand_index = [&](const IrArray::Index& operand_index) {
714 for (int64 i = 0; i < rank; ++i) {
715 llvm::Value* selected_index_address_slot =
716 InBoundsGEP(selected_index_address, {b_.getInt32(i)});
717 Store(operand_index[i], selected_index_address_slot);
718 }
719 };
720 IrArray operand_array = GetIrArray(*operand, *select_and_scatter);
721 IrArray::Index operand_index(operand_multi_index, operand->shape(),
722 index_type);
723 llvm::Value* operand_data =
724 operand_array.EmitReadArrayElement(operand_index, &b_);
725 Store(operand_data, selected_value_address);
726 save_operand_index(operand_index);
727 Store(b_.getInt1(true), initialized_flag_address);
728
729 // If the initialized_flag is true, call the `select` function to
730 // potentially update the selected value and index with the currently
731 // visiting operand.
732 llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &b_);
733 llvm::Value* operand_address =
734 operand_array.EmitArrayElementAddress(operand_index, &b_);
735 llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
736 llvm_ir::PrimitiveTypeToIrType(PRED,
737 ir_emitter_context_->llvm_module()),
738 "select_return_buffer", &b_);
739 TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
740 *select_and_scatter->select(),
741 {selected_value_address, operand_address}, select_return_buffer));
742 llvm::Value* result = Load(select_return_buffer);
743
744 // If the 'select' function returns false, update the selected value and the
745 // index to the currently visiting operand.
746 llvm::Value* cond = ICmpNE(
747 result,
748 llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(
749 PRED, ir_emitter_context_->llvm_module()),
750 0),
751 "boolean_predicate");
752 llvm_ir::LlvmIfData if_select_lhs =
753 llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
754 llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
755 Store(Load(operand_address), selected_value_address);
756 save_operand_index(operand_index);
757
758 // After iterating over the window elements, scatter the source element to
759 // the selected index of the output. The value we store at the output
760 // location is computed by calling the `scatter` function with the source
761 // value and the current output value.
762 llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
763 &b_);
764 std::vector<llvm::Value*> selected_multi_index;
765 for (int64 i = 0; i < rank; ++i) {
766 llvm::Value* selected_index_address_slot =
767 InBoundsGEP(selected_index_address, {b_.getInt32(i)});
768 selected_multi_index.push_back(Load(selected_index_address_slot));
769 }
770 llvm::Value* source_value_address =
771 GetIrArray(*source, *select_and_scatter)
772 .EmitArrayElementAddress(source_index, &b_);
773 IrArray::Index selected_index(selected_multi_index,
774 select_and_scatter->shape(),
775 operand_index.GetType());
776 llvm::Value* output_value_address =
777 GetIrArray(*select_and_scatter, *select_and_scatter)
778 .EmitArrayElementAddress(selected_index, &b_);
779 return EmitAtomicOperationForNestedComputation(
780 *select_and_scatter->scatter(), output_value_address,
781 source_value_address);
782 };
783
784 UpdateLaunchDimensions(
785 launch_dimensions,
786 // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk
787 // consisting of two thunks, an initializer KernelThunk that initializes
788 // the output and another KernelThunk that accumulates the scattered
789 // elements.
790 select_and_scatter_thunk->thunks().back().get(),
791 ir_emitter_context_->llvm_module());
792 AddThunkToThunkSequence(std::move(select_and_scatter_thunk));
793 return ParallelLoopEmitter(loop_body_emitter, source->shape(),
794 launch_dimensions, &b_)
795 .EmitLoop(IrName(select_and_scatter), index_type);
796 }
797
HandleWhile(HloInstruction * xla_while)798 Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) {
799 HloComputation* condition = xla_while->while_condition();
800 TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) &&
801 condition->root_instruction()->shape().element_type() == PRED)
802 << "While condition computation must return bool";
803 // Build ForThunk for conformant while loops, otherwise build WhileThunk.
804 auto config = xla_while->backend_config<WhileLoopBackendConfig>();
805 if (config.ok() && config.ValueOrDie().has_known_trip_count()) {
806 AddThunkToThunkSequence(
807 BuildForThunk(xla_while, config.ValueOrDie().known_trip_count().n()));
808 } else {
809 AddThunkToThunkSequence(BuildWhileThunk(xla_while));
810 }
811 return Status::OK();
812 }
813
HandleRng(HloInstruction * rng)814 Status IrEmitterUnnested::HandleRng(HloInstruction* rng) {
815 return Unimplemented("Rng should be expanded for GPU.");
816 }
817
HandleRngGetAndUpdateState(HloInstruction * rng_state)818 Status IrEmitterUnnested::HandleRngGetAndUpdateState(
819 HloInstruction* rng_state) {
820 // Emit a kernel to increment the global state for Philox RNG algorithm.
821 AddThunkToThunkSequence(
822 BuildKernelThunk(rng_state, /*implements_whole_instruction=*/true));
823
824 llvm::Value* old_state = llvm_ir::RngGetAndUpdateState(
825 Cast<HloRngGetAndUpdateStateInstruction>(rng_state)->delta(), module_,
826 &b_);
827
828 llvm::Value* output_address =
829 GetIrArray(*rng_state, *rng_state)
830 .EmitArrayElementAddress(
831 llvm_ir::IrArray::Index(
832 /*linear=*/b_.getInt64(0), rng_state->shape(), &b_),
833 &b_, "rng_state_address");
834 output_address = BitCast(
835 output_address, llvm::PointerType::get(
836 old_state->getType(),
837 output_address->getType()->getPointerAddressSpace()));
838 Store(old_state, output_address);
839
840 return Status::OK();
841 }
842
HandleScatter(HloInstruction * scatter)843 Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
844 const HloInstruction* operand = scatter->operand(0);
845 const HloInstruction* scatter_indices = scatter->operand(1);
846 const HloInstruction* updates = scatter->operand(2);
847 std::vector<std::unique_ptr<Thunk>> thunks;
848
849 // Copy the operand into the output if it's not the same buffer already.
850 auto operand_buffer = GetAllocationSlice(*operand);
851 auto destination_buffer = GetAllocationSlice(*scatter);
852 if (operand_buffer != destination_buffer) {
853 thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
854 /*source_address=*/operand_buffer,
855 /*destination_buffer=*/destination_buffer,
856 /*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()),
857 /*hlo_instruction=*/nullptr));
858 }
859
860 thunks.push_back(
861 BuildKernelThunk(scatter,
862 /*implements_whole_instruction=*/thunks.empty()));
863
864 TF_RETURN_IF_ERROR(EmitScatter(
865 thunks.back().get(), scatter,
866 /*scatter_indices_gen=*/
867 [=](const IrArray::Index& index) {
868 return GetIrArray(*scatter_indices, *scatter)
869 .EmitReadArrayElement(index, &b_, "scatter_index");
870 },
871 /*updates_gen=*/
872 [=](const IrArray::Index& index) {
873 return GetIrArray(*updates, *scatter)
874 .EmitReadArrayElement(index, &b_, "update");
875 }));
876
877 // Elide the sequential thunk if there's no copy.
878 if (thunks.size() == 1) {
879 AddThunkToThunkSequence(std::move(thunks[0]));
880 } else {
881 AddThunkToThunkSequence(
882 absl::make_unique<SequentialThunk>(std::move(thunks), scatter));
883 }
884
885 return Status::OK();
886 }
887
EmitScatter(Thunk * thunk,HloInstruction * scatter,const llvm_ir::ElementGenerator & scatter_indices_gen,const llvm_ir::ElementGenerator & updates_gen)888 Status IrEmitterUnnested::EmitScatter(
889 Thunk* thunk, HloInstruction* scatter,
890 const llvm_ir::ElementGenerator& scatter_indices_gen,
891 const llvm_ir::ElementGenerator& updates_gen) {
892 const HloInstruction* operand = scatter->operand(0);
893 const HloInstruction* scatter_indices = scatter->operand(1);
894 const HloInstruction* updates = scatter->operand(2);
895 const ScatterDimensionNumbers& dim_numbers =
896 scatter->scatter_dimension_numbers();
897 CHECK(ShapeUtil::Equal(scatter->shape(), operand->shape()));
898
899 auto loop_body_emitter = [&](const IrArray::Index& index) -> Status {
900 std::vector<llvm::Value*> raw_window_multidim;
901 std::vector<llvm::Value*> input_scatter_multidim;
902 std::vector<int64> raw_window_bounds;
903
904 // Partition the index into window indices and scatter indices.
905 for (int64 i = 0, e = index.size(); i != e; ++i) {
906 // For window indices also remember the window size, this comes in handy
907 // later.
908 if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) {
909 raw_window_multidim.push_back(index[i]);
910 raw_window_bounds.push_back(updates->shape().dimensions(i));
911 } else {
912 input_scatter_multidim.push_back(index[i]);
913 }
914 }
915 DCHECK_EQ(raw_window_multidim.size(),
916 dim_numbers.update_window_dims_size());
917
918 // Apply inserted_window_dims to the window dimensions.
919 int64 raw_window_multidim_idx = 0;
920 std::vector<llvm::Value*> input_window_multidim;
921 std::vector<int64> input_window_bounds;
922 for (int64 i = 0, e = operand->shape().rank(); i != e; ++i) {
923 if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) {
924 input_window_bounds.push_back(1); // Trivial dimension.
925 input_window_multidim.push_back(index.GetConstantWithIndexType(0));
926 } else {
927 input_window_bounds.push_back(
928 raw_window_bounds[raw_window_multidim_idx]);
929 input_window_multidim.push_back(
930 raw_window_multidim[raw_window_multidim_idx]);
931 ++raw_window_multidim_idx;
932 }
933 }
934 DCHECK_EQ(input_window_multidim.size(), operand->shape().rank());
935
936 // Insert a 1 dimension at the end if index_vector_dim requests one.
937 Shape scatter_indices_shape = scatter_indices->shape();
938 if (dim_numbers.index_vector_dim() == scatter_indices_shape.rank()) {
939 scatter_indices_shape.add_dimensions(1);
940 scatter_indices_shape.mutable_layout()->add_minor_to_major(
941 dim_numbers.index_vector_dim());
942 }
943
944 // Now load the indices corresponding to the current window from
945 // scatter_indices.
946 std::vector<llvm::Value*> raw_scatter_index_multidim =
947 input_scatter_multidim;
948 raw_scatter_index_multidim.insert(
949 raw_scatter_index_multidim.begin() + dim_numbers.index_vector_dim(),
950 nullptr);
951 llvm::Value* is_in_bounds = b_.getTrue();
952 for (int64 i = 0, e = dim_numbers.scatter_dims_to_operand_dims_size();
953 i != e; ++i) {
954 // Our index is stored along index_vector_dim, insert that into the lookup
955 // index into scatter_indices.
956 raw_scatter_index_multidim[dim_numbers.index_vector_dim()] =
957 index.GetConstantWithIndexType(i);
958 llvm_ir::IrArray::Index raw_scatter_index_index(
959 raw_scatter_index_multidim, scatter_indices_shape, index.GetType());
960
961 int64 operand_dim = dim_numbers.scatter_dims_to_operand_dims(i);
962 TF_ASSIGN_OR_RETURN(
963 llvm::Value* const loaded_scatter_index,
964 scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape(
965 scatter_indices_shape, scatter_indices->shape(), &b_)));
966 // And add the index to our window index. This yields the output index.
967 llvm::Value* casted_scatter_index =
968 IntCast(loaded_scatter_index, index.GetType(),
969 /*isSigned=*/true);
970 llvm::Value* dim_offset =
971 Add(input_window_multidim[operand_dim], casted_scatter_index);
972 input_window_multidim[operand_dim] = dim_offset;
973
974 // Also do the bounds check now.
975 int64 max_index = operand->shape().dimensions(operand_dim) -
976 input_window_bounds[operand_dim] + 1;
977 // is_in_bounds = index >= 0 && index < dim_size-window_size+1
978 // --> index u< dim_size-window_size+1
979 is_in_bounds =
980 And(is_in_bounds, ICmpULT(casted_scatter_index,
981 index.GetConstantWithIndexType(max_index)));
982 }
983
984 llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse(
985 is_in_bounds, "scatter.in_bounds", &b_, /*emit_else=*/false);
986 llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_);
987 // All done, now just read from the calculated input from the window, and do
988 // an atomic store to the calculated location in the output.
989 HloInstruction* output_hlo =
990 scatter->IsFused() ? scatter->parent()->FusionInstruction() : scatter;
991 llvm_ir::IrArray::Index input_window_index(
992 input_window_multidim, output_hlo->shape(), index.GetType());
993 llvm::Value* output_address =
994 GetIrArray(*output_hlo, *output_hlo)
995 .EmitArrayElementAddress(input_window_index, &b_);
996 llvm::Value* input_address = Alloca(llvm_ir::PrimitiveTypeToIrType(
997 updates->shape().element_type(), module_));
998 TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index));
999 Store(input_ir_value, input_address);
1000
1001 if (!scatter->unique_indices()) {
1002 return EmitAtomicOperationForNestedComputation(
1003 *scatter->to_apply(), output_address, input_address);
1004 } else {
1005 return EmitCallToNestedComputation(*scatter->to_apply(),
1006 {output_address, input_address},
1007 output_address);
1008 }
1009 };
1010
1011 // Launch a kernel that reads every element in the updates tensor. We could
1012 // also do one kernel per window instead if bounds checks turn out to be a
1013 // bottleneck.
1014 LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
1015 updates->shape(), ir_emitter_context_->device_description());
1016 UpdateLaunchDimensions(launch_dimensions, thunk,
1017 ir_emitter_context_->llvm_module());
1018
1019 return ParallelLoopEmitter(loop_body_emitter, updates->shape(),
1020 launch_dimensions, &b_)
1021 .EmitLoop(IrName(scatter),
1022 GetIndexTypeForKernel(scatter, launch_dimensions.launch_bound(),
1023 &b_));
1024 }
1025
HandleSelect(HloInstruction * select)1026 Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
1027 return IrEmitter::HandleSelect(select);
1028 }
1029
HandleSort(HloInstruction * sort)1030 Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
1031 std::vector<std::unique_ptr<Thunk>> thunks;
1032 Shape keys_shape = sort->operand(0)->shape();
1033 int64 dimension_to_sort = sort->dimensions(0);
1034 for (int64 i = 0; i < sort->operand_count(); ++i) {
1035 ShapeIndex shape_index =
1036 sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({});
1037 // We assume that the layout of all involved operands and outputs is the
1038 // same.
1039 TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(keys_shape,
1040 sort->operand(i)->shape()));
1041 TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
1042 keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index)));
1043
1044 // If possible, we share buffers. If that is not possible, we need to copy
1045 // the values, because the emitter does the sorting in-place.
1046 auto destination_buffer = GetAllocationSlice(*sort, shape_index);
1047 auto source_address = GetAllocationSlice(*sort->operand(i));
1048 if (destination_buffer != source_address) {
1049 // TODO(b/26783907): Figure out why we never seem to share buffers for
1050 // key/value sort.
1051 thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
1052 /*source_address=*/source_address,
1053 /*destination_buffer=*/destination_buffer,
1054 /*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape()),
1055 nullptr));
1056 }
1057 }
1058
1059 uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
1060 int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound);
1061 CHECK_GE(1ULL << num_stages, dimension_to_sort_bound);
1062 CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound);
1063
1064 // Naive C++ code for the outer loops:
1065 //
1066 // for (int64 stage = 0; stage < Log2Ceiling(dimension_to_sort_bound);
1067 // ++stage) {
1068 // int64 first_xor_mask = (1LL << (stage + 1)) - 1;
1069 // SortInPlace(first_xor_mask);
1070 // for (int64 mask = stage - 1; mask >= 0; --mask) {
1071 // int64 later_xor_mask = 1LL << mask;
1072 // SortInPlace(later_xor_mask);
1073 // }
1074 // }
1075 //
1076 // This follows the alternative representation of the algorithm described on
1077 // Wikipedia: https://en.wikipedia.org/wiki/Bitonic_sorter
1078 //
1079 // Each mask specifies how to derive from one position in the array the
1080 // position with which it should be compared (we calculate the xor of the
1081 // position with the mask).
1082 // As an optimization, we can move the 'mask' loop to inside the
1083 // sorting/comparison loop if the comparisons happen within a small block of
1084 // the array. To make this work, we collect all consecutive masks that are
1085 // smaller than our chosen power of 2 tile size, and pass them to SortInPlace.
1086 // Each thread then processes one tile of data.
1087
1088 const uint64 kTileSize = std::min(2048ULL, 1ULL << num_stages);
1089
1090 // If we cannot combine several xor masks together, we don't use tiling, so we
1091 // calculate the standard launch dimensions for the shape. However we only
1092 // need to iterate through ~half of the dimension to sort (rounded up to the
1093 // next highest power of 2), because each iteration compares one pair of
1094 // elements.
1095 Shape standard_iteration_shape = keys_shape;
1096 uint64 standard_num_iterations_in_sort_dim = 1ULL << (num_stages - 1);
1097 standard_iteration_shape.set_dimensions(dimension_to_sort,
1098 standard_num_iterations_in_sort_dim);
1099 LaunchDimensions standard_launch_dimensions = CalculateLaunchDimensions(
1100 standard_iteration_shape, ir_emitter_context_->device_description());
1101
1102 // Calculate the launch dimensions for the case where we use tiling. We split
1103 // the dimension that should be sorted into tiles of size 'kTileSize'. This
1104 // means we first need to round 'dimension_to_sort_bound' up to be a multiple
1105 // of the tile size.
1106 int64 rounded_bound = RoundUpToNearest(dimension_to_sort_bound, kTileSize);
1107 Shape iteration_shape = keys_shape;
1108
1109 // We iterate through the element pairs that should be compared.
1110 uint64 num_iterations_in_sort_dim = rounded_bound / 2;
1111 iteration_shape.set_dimensions(dimension_to_sort, num_iterations_in_sort_dim);
1112 uint64 num_iterations = ShapeUtil::ElementsIn(iteration_shape);
1113
1114 // For correctness reasons we need exactly 'kTileSize' / 2 many threads per
1115 // block. Each thread is responsible for copying exactly two adjacent elements
1116 // into shared memory, and then does a comparison of two possibly different
1117 // elements taken from shared memory.
1118 const uint64 kThreadsPerBlock = kTileSize / 2;
1119
1120 // Check whether we should use any tiling. We might not be able to use it if
1121 // we have not enough threads, or not enough shared memory. Also it does not
1122 // give a speedup if the tile size is < 128.
1123 int64 total_shared_memory_needed = 0;
1124 for (int64 i = 0; i < sort->operand_count(); ++i) {
1125 total_shared_memory_needed +=
1126 kTileSize * ShapeUtil::ByteSizeOfPrimitiveType(
1127 sort->operand(i)->shape().element_type());
1128 }
1129 bool no_tiling =
1130 kTileSize < 128 ||
1131 kThreadsPerBlock >
1132 ir_emitter_context_->device_description().threads_per_block_limit() ||
1133 total_shared_memory_needed >
1134 ir_emitter_context_->device_description().shared_memory_per_block();
1135
1136 uint64 num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock);
1137 LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock);
1138
1139 auto emit_kernel = [&](absl::Span<const int64> xor_masks) {
1140 thunks.push_back(
1141 BuildKernelThunk(sort, /*implements_whole_instruction=*/false));
1142 LaunchDimensions launch_dimensions = xor_masks.size() > 1
1143 ? tiled_launch_dimensions
1144 : standard_launch_dimensions;
1145 UpdateLaunchDimensions(launch_dimensions, thunks.back().get(),
1146 ir_emitter_context_->llvm_module());
1147 std::vector<IrArray> values_arrays;
1148 values_arrays.reserve(sort->operand_count());
1149 for (int64 i = 0; i < sort->operand_count(); ++i) {
1150 ShapeIndex shape_index =
1151 sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({});
1152 values_arrays.push_back(GetIrArray(*sort, *sort, shape_index));
1153 }
1154 return llvm_ir::EmitSortInPlace(
1155 dimension_to_sort, values_arrays, IrName(sort), xor_masks, &b_,
1156 launch_dimensions,
1157 xor_masks.size() > 1 ? num_iterations_in_sort_dim
1158 : standard_num_iterations_in_sort_dim,
1159 kTileSize,
1160 [&](absl::Span<llvm::Value* const> operands, llvm::Value* output) {
1161 return EmitCallToNestedComputation(*sort->to_apply(), operands,
1162 output);
1163 });
1164 };
1165 std::vector<int64> xor_masks;
1166 for (int64 stage = 0; stage < num_stages; ++stage) {
1167 for (int64 mask = stage; mask >= 0; --mask) {
1168 int64 xor_mask;
1169 if (mask == stage) {
1170 xor_mask = (1LL << (stage + 1)) - 1;
1171 } else {
1172 xor_mask = 1LL << mask;
1173 }
1174 if (xor_mask >= kTileSize || no_tiling) {
1175 if (!xor_masks.empty()) {
1176 TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
1177 xor_masks.clear();
1178 }
1179 TF_RETURN_IF_ERROR(emit_kernel({xor_mask}));
1180 } else {
1181 xor_masks.push_back(xor_mask);
1182 }
1183 }
1184 }
1185 if (!xor_masks.empty()) {
1186 TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
1187 }
1188
1189 AddThunkToThunkSequence(
1190 absl::make_unique<SequentialThunk>(std::move(thunks), sort));
1191 return Status::OK();
1192 }
1193
HandleTupleSelect(HloInstruction * tuple_select)1194 Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) {
1195 AddThunkToThunkSequence(
1196 BuildKernelThunk(tuple_select, /*implements_whole_instruction=*/true));
1197 return IrEmitter::HandleTupleSelect(tuple_select);
1198 }
1199
HandleReplicaId(HloInstruction * hlo)1200 Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) {
1201 AddThunkToThunkSequence(
1202 absl::make_unique<ReplicaIdThunk>(GetAllocationSlice(*hlo), hlo));
1203 return Status::OK();
1204 }
1205
HandleCollectivePermute(HloInstruction * hlo)1206 Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {
1207 AddThunkToThunkSequence(absl::make_unique<CollectivePermuteThunk>(
1208 GetAllocationSlice(*hlo->operand(0)), GetAllocationSlice(*hlo), hlo));
1209 return Status::OK();
1210 }
1211
1212 namespace {
1213
1214
1215 } // namespace
1216
HandleAllReduce(HloInstruction * crs)1217 Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
1218 VLOG(2) << "AllReduce; replica count: " << hlo_module_config_.replica_count()
1219 << "; operand count: " << crs->operand_count()
1220 << "; NCCL is enabled: " << NcclAllReduceThunk::NcclIsEnabled();
1221
1222 // Note the replica_count == 1 case is handled via device-to-device copy
1223 // below.
1224 bool should_use_nccl_thunk = hlo_module_config_.replica_count() > 1 &&
1225 NcclAllReduceThunk::CanImplement(crs);
1226
1227 if (should_use_nccl_thunk) {
1228 CHECK(crs->operand(0)->shape().IsArray())
1229 << "Operands to all-reduce must be arrays: " << crs->ToString();
1230 AddThunkToThunkSequence(absl::make_unique<NcclAllReduceThunk>(
1231 /*replica_count=*/hlo_module_config_.replica_count(),
1232 /*elements=*/ShapeUtil::ElementsIn(crs->operand(0)->shape()),
1233 /*source_address=*/GetAllocationSlice(*crs->operand(0)),
1234 /*destination_buffer=*/GetAllocationSlice(*crs), crs));
1235 return Status::OK();
1236 }
1237
1238 if (hlo_module_config_.replica_count() != 1) {
1239 // TODO(b/33011107): Support more AllReduce configurations on GPU.
1240 string message = absl::StrFormat(
1241 "Requested AllReduce not implemented on GPU; replica_count: %d; "
1242 "operand_count: %d; IsCrossReplicaAllReduce: %d; NCCL support: %d",
1243 hlo_module_config_.replica_count(), crs->operand_count(),
1244 crs->IsCrossReplicaAllReduce(), NcclAllReduceThunk::NcclIsEnabled());
1245 if (crs->operand_count() > 0) {
1246 absl::StrAppendFormat(
1247 &message, "; first operand array element-type: %s",
1248 PrimitiveType_Name(crs->operand(0)->shape().element_type()));
1249 }
1250 return Unimplemented("%s", message);
1251 }
1252
1253 // CRS with one operand and one replica is simply the identity function.
1254 // Buffer assignment expects a copy, so that's what we do.
1255 //
1256 // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely
1257 // in algebraic-simplifier, but currently on some platforms
1258 // HloModuleConfig::num_replicas changes between when the module is compiled
1259 // and when it's run.
1260 if (crs->operand_count() == 1) {
1261 CHECK(crs->operand(0)->shape().IsArray())
1262 << "Operands to all-reduce must be arrays: " << crs->ToString();
1263 AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
1264 /*source_address=*/GetAllocationSlice(*crs->operand(0)),
1265 /*destination_buffer=*/GetAllocationSlice(*crs),
1266 /*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs));
1267 return Status::OK();
1268 }
1269
1270 // One-replica CRS with multiple operands produces a tuple of the inputs.
1271 // Again, buffer assignment expects us to copy each.
1272 std::vector<std::unique_ptr<Thunk>> thunks;
1273 std::vector<BufferAllocation::Slice> tuple_element_buffers;
1274 for (int64 i = 0; i < crs->operand_count(); ++i) {
1275 tuple_element_buffers.push_back(ir_emitter_context_->buffer_assignment()
1276 .GetUniqueSlice(crs, {i})
1277 .ValueOrDie());
1278 thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
1279 /*source_address=*/GetAllocationSlice(*crs->operand(i)),
1280 /*destination_buffer=*/tuple_element_buffers.back(),
1281 /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr));
1282 }
1283
1284 // Output a tuple of the buffers above.
1285 thunks.push_back(absl::make_unique<TupleThunk>(
1286 tuple_element_buffers, GetAllocationSlice(*crs), nullptr));
1287 AddThunkToThunkSequence(
1288 absl::make_unique<SequentialThunk>(std::move(thunks), crs));
1289 return Status::OK();
1290 }
1291
1292 Status IrEmitterUnnested::HandleInfeed(HloInstruction* xla_infeed) {
1293 return ThunkEmitter(this).HandleInfeed(xla_infeed);
1294 }
1295
1296 Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) {
1297 return ThunkEmitter(this).HandleOutfeed(outfeed);
1298 }
1299
1300 Status IrEmitterUnnested::HandleAfterAll(HloInstruction* after_all) {
1301 return Status::OK();
1302 }
1303
1304 // Figures out how to access the buffers for all subshapes of hlo's operands and
1305 // for hlo itself (i.e. all the buffers produced by HLO).
1306 //
1307 // Returns a map keyed on the pair {HloInstruction, ShapeIndex}. The value for
1308 // this key is a pair {Slice, ShapeIndex}, where the slice tells you the root
1309 // buffer to look in, and the ShapeIndex describes how to dereference starting
1310 // at that buffer to get to the buffer in question.
1311 //
1312 // For example, if {hlo, {1}} is mapped to {slice, {3, 4}}, then the buffer for
1313 // hlo at ShapeIndex {1} (i.e. the buffer for the second tuple element of hlo)
1314 // is found at slice[3][4]. That is, slice is a void***, which we dereference
1315 // twice -- first at index 3, and then at index 4 -- to get the address of our
1316 // buffer.
1317 //
1318 // This function conservatively assumes that we'll touch all sub-buffers of
1319 // every operand and of the output.
1320 static std::map<std::pair<const HloInstruction*, ShapeIndex>,
1321 std::pair<BufferAllocation::Slice, ShapeIndex>>
1322 GetHloBufferSlices(const HloInstruction* hlo,
1323 const BufferAssignment& buffer_assn) {
1324 std::map<std::pair<const HloInstruction*, ShapeIndex>,
1325 std::pair<BufferAllocation::Slice, ShapeIndex>>
1326 slices;
1327
1328 // Tries to find a slice plus an array of indices i1, ..., iN such that the
1329 // sub-buffer for instr at index can be found at slice[i1]...[iN].
1330 auto find_slice_for = [&](const HloInstruction* instr,
1331 const ShapeIndex& index)
1332 -> optional<std::pair<BufferAllocation::Slice, ShapeIndex>> {
1333 // Simple, common case: Is the buffer for instr known at runtime? If so,
1334 // we're done.
1335 auto slice = buffer_assn.GetUniqueSlice(instr, index);
1336 if (slice.ok()) {
1337 return {{slice.ValueOrDie(), ShapeIndex()}};
1338 }
1339
1340 // If that didn't work, walk up any bitcasts that we might see. These must
1341 // appear before any GTE instructions, because it's illegal to bitcast to a
1342 // tuple type.
1343 const HloInstruction* parent = instr;
1344 while (parent->opcode() == HloOpcode::kBitcast) {
1345 parent = parent->operand(0);
1346
1347 auto slice = buffer_assn.GetUniqueSlice(parent, {});
1348 if (slice.ok()) {
1349 return {{slice.ValueOrDie(), ShapeIndex()}};
1350 }
1351 }
1352
1353 // Check whether instr is a GTE instruction. If it is, see if we can get a
1354 // buffer for its parent, and continue walking up parents until we find a
1355 // defined buffer or we hit something that's not a GTE.
1356 ShapeIndex gte_indices;
1357 while (parent->opcode() == HloOpcode::kGetTupleElement) {
1358 gte_indices.push_front(parent->tuple_index());
1359 parent = parent->operand(0);
1360
1361 auto slice = buffer_assn.GetUniqueSlice(parent, {});
1362 if (slice.ok()) {
1363 return {{slice.ValueOrDie(), gte_indices}};
1364 }
1365 }
1366
1367 // Finally, if we don't know the buffer for instr at index, see if we know
1368 // the buffer for instr at index without its last element. If so, we can
1369 // dynamically find the buffer for instr by dereferencing a pointer in that
1370 // buffer. Continue looking this way until we run out of elements in
1371 // 'index'.
1372 //
1373 // We can almost always get a buffer without resorting to this. The only
1374 // exception is for cases where the relevant sub-buffer is truly unknowable,
1375 // for example the sub-buffer of a tuple-shaped select.
1376 ShapeIndex new_index = index;
1377 while (!new_index.empty()) {
1378 gte_indices.push_front(new_index.back());
1379 new_index.pop_back();
1380 auto slice = buffer_assn.GetUniqueSlice(instr, new_index);
1381 if (slice.ok()) {
1382 return {{slice.ValueOrDie(), gte_indices}};
1383 }
1384 }
1385
1386 return nullopt;
1387 };
1388
1389 // Adds entries for all subshapes of instr to `slices`.
1390 auto add_slices_for = [&](const HloInstruction* instr) {
1391 ShapeUtil::ForEachSubshape(
1392 instr->shape(), [&](const Shape& /*shape*/, const ShapeIndex& index) {
1393 if (slices.count({instr, index})) {
1394 // HLOs can have duplicate operands; don't bother redoing work.
1395 return;
1396 }
1397 auto maybe_slice = find_slice_for(instr, index);
1398 if (maybe_slice.has_value()) {
1399 slices[{instr, index}] = *maybe_slice;
1400 } else {
1401 VLOG(1) << "Couldn't find buffer for " << instr->ToString()
1402 << " at index " << index.ToString();
1403 }
1404 });
1405 };
1406
1407 add_slices_for(hlo);
1408 for (const HloInstruction* operand : hlo->operands()) {
1409 // Conservatively assume we'll need the buffers for all subshapes of the
1410 // operand.
1411 add_slices_for(operand);
1412 }
1413
1414 return slices;
1415 }
1416
1417 std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
1418 const HloInstruction* inst, bool implements_whole_instruction,
1419 int unroll_factor) {
1420 const BufferAssignment& buffer_assn =
1421 ir_emitter_context_->buffer_assignment();
1422
1423 std::map<std::pair<const HloInstruction*, ShapeIndex>,
1424 std::pair<BufferAllocation::Slice, ShapeIndex>>
1425 hlo_slices = GetHloBufferSlices(inst, buffer_assn);
1426
1427 // Figure out which buffer allocations need to be passed as arguments to our
1428 // kernel. This is simply all of the allocations referenced in hlo_slices,
1429 // plus the XLA temp buffer (if we have it). We always include the temp
1430 // buffer because even if the kernel itself doesn't use it, a nested
1431 // subcomputation within the kernel (e.g. a kMap's computation) might.
1432 std::unordered_set<const BufferAllocation*> buffers_needed;
1433 for (const auto& kv : hlo_slices) {
1434 buffers_needed.insert(kv.second.first.allocation());
1435 }
1436 absl::optional<const BufferAllocation*> temp_buffer;
1437 for (const BufferAllocation& alloc : buffer_assn.Allocations()) {
1438 if (alloc.IsPreallocatedTempBuffer()) {
1439 if (!temp_buffer.has_value()) {
1440 temp_buffer = &alloc;
1441 } else {
1442 LOG(FATAL) << "Multiple temp buffers found, but only one is allowed!";
1443 }
1444 }
1445 }
1446 if (temp_buffer.has_value()) {
1447 buffers_needed.insert(*temp_buffer);
1448 }
1449
1450 // We'll pass a pointer to each of the elements of `non_constant_buffers` to
1451 // our kernel, in this order.
1452 std::vector<const BufferAllocation*> non_constant_buffers;
1453 absl::c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers),
1454 [](const BufferAllocation* allocation) {
1455 return !allocation->is_constant();
1456 });
1457
1458 absl::c_sort(non_constant_buffers,
1459 [](const BufferAllocation* a, const BufferAllocation* b) {
1460 return a->index() < b->index();
1461 });
1462
1463 llvm::Function* kernel = BuildKernelPrototype(*inst, non_constant_buffers);
1464
1465 // Build a map from a BufferAllocation to the corresponding argument in our
1466 // kernel.
1467 std::unordered_map<const BufferAllocation*, llvm::Value*> kernel_args;
1468 {
1469 auto arg_it = kernel->arg_begin();
1470 auto buffers_it = non_constant_buffers.begin();
1471 for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) {
1472 kernel_args[*buffers_it] = arg_it;
1473 }
1474 }
1475
1476 // For each buffer our kernel might want to touch, bind it to a value derived
1477 // from our kernel args.
1478 for (const auto& kv : hlo_slices) {
1479 const HloInstruction* instr = kv.first.first;
1480 const ShapeIndex& index = kv.first.second;
1481 const BufferAllocation::Slice& slice = kv.second.first;
1482 const ShapeIndex& gte_index = kv.second.second;
1483
1484 VLOG(3) << "Buffer for " << instr->ToString() << " at " << index.ToString()
1485 << " is found in slice " << slice.ToString() << " at GTE index "
1486 << gte_index.ToString();
1487
1488 llvm::Value* loc;
1489 if (slice.allocation()->is_constant()) {
1490 loc = ir_emitter_context_->llvm_module()->getGlobalVariable(
1491 llvm_ir::ConstantBufferAllocationToGlobalName(*slice.allocation()));
1492 CHECK_NE(loc, nullptr);
1493 } else {
1494 loc = InBoundsGEP(kernel_args.at(slice.allocation()),
1495 {b_.getInt64(slice.offset())});
1496 }
1497
1498 // If gte_index is nonempty, we have to dereference `loc` to get to the
1499 // value we're ultimately interested in.
1500 llvm::Type* int8_double_pointer =
1501 llvm::PointerType::get(b_.getInt8PtrTy(), /*AddressSpace=*/0);
1502 for (int64 idx : gte_index) {
1503 loc = b_.CreatePointerBitCastOrAddrSpaceCast(loc, int8_double_pointer);
1504 loc = Load(InBoundsGEP(loc, {b_.getInt64(idx)}));
1505 }
1506
1507 bindings_.BindHloToIrValue(*instr, loc, index);
1508 }
1509
1510 // Bind the temp buffer so that nested subcomputations can find it if they
1511 // need.
1512 if (temp_buffer.has_value()) {
1513 bindings_.SetTempBufferBase(kernel_args.at(*temp_buffer));
1514 } else {
1515 bindings_.SetTempBufferBase(
1516 llvm::ConstantPointerNull::get(b_.getInt8PtrTy()));
1517 }
1518
1519 return absl::make_unique<KernelThunk>(
1520 non_constant_buffers, std::string(kernel->getName()),
1521 implements_whole_instruction ? inst : nullptr, unroll_factor);
1522 }
1523
1524 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
1525 HloInstruction* hlo, const ShapeIndex& index) {
1526 bool fused = HloOpcode::kFusion == hlo->opcode();
1527 HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo;
1528 HloInstruction* init_value_operand = [&] {
1529 switch (inst->opcode()) {
1530 case HloOpcode::kSelectAndScatter:
1531 return inst->mutable_operand(2);
1532 case HloOpcode::kReduce:
1533 return inst->mutable_operand(1);
1534 case HloOpcode::kTuple:
1535 CHECK(hlo->IsMultiOutputFusion())
1536 << ": " << hlo->ToString() << " is not a multi-output fusion.";
1537 CHECK(inst->operand(index.back())->opcode() == HloOpcode::kReduce)
1538 << ": Found '" << inst->operand(index.back())->opcode() << "' in "
1539 << inst->ToString() << " but expected 'reduce'.";
1540 // For multi-output fusion look through the tuple.
1541 return inst->mutable_operand(index.back())->mutable_operand(1);
1542 default:
1543 LOG(FATAL) << "Opcode " << inst->opcode()
1544 << " should not need an initializer.";
1545 }
1546 }();
1547
1548 const HloInstruction* init_value = init_value_operand;
1549 if (fused && init_value->opcode() == HloOpcode::kParameter) {
1550 init_value = hlo->operand(init_value->parameter_number());
1551 }
1552
1553 // Initializer thunks don't implement a whole instruction, and we want to
1554 // profile the whole instruction instead of the individual thunks it consists
1555 // of. Therefore we pass nullptr as the HloInstruction* to the thunks we
1556 // generate below.
1557 //
1558 // In the common case, the initializer is a constant. In this case, emit a
1559 // device-memset call if we can. Currently StreamExecutor only supports
1560 // zeroing and 32-bit memsets.
1561 if (init_value->IsConstant()) {
1562 CHECK(ShapeUtil::IsScalar(init_value->shape()));
1563 int64 num_bytes = ShapeUtil::ByteSizeOfElements(init_value->shape());
1564 const auto& literal = init_value->literal();
1565
1566 // Are all the bytes of this scalar equal to 0? If so, we can create a
1567 // MemzeroThunk.
1568 absl::Span<const uint8> literal_bytes(
1569 reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes);
1570 if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
1571 return {absl::make_unique<MemzeroThunk>(GetAllocationSlice(*hlo, index),
1572 nullptr)};
1573 }
1574
1575 // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
1576 // repeating the literal 4 or 2 times, so long as the destination buffer is
1577 // an even multiple of 32 bits long.
1578 const Shape& output_shape = ShapeUtil::GetSubshape(hlo->shape(), index);
1579 if ((num_bytes == 1 || num_bytes == 2) &&
1580 ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) {
1581 uint16 pattern16;
1582 if (num_bytes == 1) {
1583 uint8 b = literal_bytes.front();
1584 pattern16 = uint16{b} | (uint16{b} << 8);
1585 } else {
1586 memcpy(&pattern16, literal_bytes.data(), sizeof(pattern16));
1587 }
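      // For example, an 8-bit literal 0xAB gives pattern16 = 0xABAB and
      // pattern32 = 0xABABABAB, so the 32-bit memset writes 0xAB into every
      // byte of the destination buffer.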
1588 uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
1589 return {absl::make_unique<Memset32BitValueThunk>(
1590 pattern32, GetAllocationSlice(*hlo, index), nullptr)};
1591 }
1592
1593 // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
1594 // memset so long as all 32-bit words of the scalar are equal to each other.
1595 if (num_bytes >= 4 && num_bytes % 4 == 0 &&
1596 memcmp(literal_bytes.data(), literal_bytes.data() + 4,
1597 literal_bytes.size() - 4) == 0) {
1598 uint32 word;
1599 memcpy(&word, literal_bytes.data(), sizeof(word));
1600 return {absl::make_unique<Memset32BitValueThunk>(
1601 word, GetAllocationSlice(*hlo, index), nullptr)};
1602 }
1603 }
1604
1605 // Otherwise fall back to our slow initializer code.
1606 std::unique_ptr<KernelThunk> kernel_thunk =
1607 BuildKernelThunk(hlo, /*implements_whole_instruction=*/false);
1608 LaunchDimensions launch_dimensions =
1609 CalculateLaunchDimensions(ShapeUtil::GetSubshape(hlo->shape(), index),
1610 ir_emitter_context_->device_description());
1611 UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
1612 ir_emitter_context_->llvm_module());
1613
1614 if (fused) {
1615 // If init_value was fused into this reduce we have to generate it first.
1616 GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
1617 ir_emitter_context_->llvm_module(),
1618 &b_, GetNestedComputer());
1619
1620 FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo),
1621 &elemental_emitter);
1622 TF_RETURN_IF_ERROR(init_value_operand->Accept(&fused_emitter));
1623 TF_RETURN_IF_ERROR(
1624 ParallelLoopEmitter(fused_emitter.GetGenerator(init_value_operand),
1625 GetIrArray(*hlo, *hlo, index), launch_dimensions,
1626 &b_)
1627 .EmitLoop(IrName(hlo)));
1628 } else {
1629 // In the unfused case the element is already there, just read from it.
1630 TF_RETURN_IF_ERROR(ParallelLoopEmitter(
1631 [=](const IrArray::Index& index) {
1632 return GetIrArray(*init_value, *hlo)
1633 .EmitReadArrayElement(index, &b_);
1634 },
1635 GetIrArray(*hlo, *hlo, index), launch_dimensions,
1636 &b_)
1637 .EmitLoop(IrName(hlo)));
1638 }
1639
1640 // Clean up state left behind by emitting the loop above. (This is normally
1641 // done in IrEmitterUnnested::Postprocess().)
1642 bindings_.UnbindAllLocalIrValues();
1643
1644 // Convert unique_ptr<KernelThunk> to StatusOr<unique_ptr<Thunk>>.
1645 return {std::move(kernel_thunk)};
1646 }
1647
1648 namespace {
1649
1650 // Checks that the buffers corresponding to the given two HLOs share the same
1651 // allocation.
1652 Status CheckHloBuffersShareAllocation(
1653 const HloInstruction* a, const HloInstruction* b, const ShapeIndex& index,
1654 const BufferAssignment& buffer_assignment) {
1655 const BufferAllocation::Slice slice_a =
1656 buffer_assignment.GetUniqueSlice(a, index).ConsumeValueOrDie();
1657 const BufferAllocation::Slice slice_b =
1658 buffer_assignment.GetUniqueSlice(b, index).ConsumeValueOrDie();
1659 if (slice_a != slice_b) {
1660 return InternalError(
1661 "instruction %s %s does not share allocation with instruction %s %s",
1662 a->ToString(), slice_a.ToString(), b->ToString(), slice_b.ToString());
1663 }
1664 return Status::OK();
1665 }
1666
1667 // Checks that all buffers used during while loop iteration share the same
1668 // buffer allocation. This includes buffers for while result, while init
1669 // operand, condition parameter, body parameter and body result.
1670 // Returns OK on success, error status otherwise.
1671 Status CheckWhileBuffersShareAllocation(
1672 const HloInstruction* xla_while,
1673 const BufferAssignment& buffer_assignment) {
1674 return ShapeUtil::ForEachSubshapeWithStatus(
1675 xla_while->shape(),
1676 [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
1677 const HloInstruction* condition_parameter =
1678 xla_while->while_condition()->parameter_instruction(0);
1679 const HloComputation* body = xla_while->while_body();
1680 const HloInstruction* body_parameter = body->parameter_instruction(0);
1681 const HloInstruction* body_result = body->root_instruction();
1682 TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
1683 xla_while, xla_while->operand(0), index, buffer_assignment));
1684 TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
1685 xla_while, condition_parameter, index, buffer_assignment));
1686 TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
1687 xla_while, body_parameter, index, buffer_assignment));
1688 TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
1689 xla_while, body_result, index, buffer_assignment));
1690 return Status::OK();
1691 });
1692 }
1693
1694 // Checks that the buffers used in a conditional instruction are shared with the
1695 // operands and result as follows:
1696 // * The result buffer of the conditional should share the allocation with the
1697 // result buffers of each branch computation.
1698 // * The buffer of operand b+1 should share the allocation with the buffer of
1699 // the parameter 0 instruction of the b'th computation.
1700 Status CheckConditionalBuffersShareAllocation(
1701 const HloInstruction* conditional,
1702 const BufferAssignment& buffer_assignment) {
1703 TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
1704 conditional->shape(),
1705 [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
1706 for (auto branch_computation : conditional->branch_computations()) {
1707 TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
1708 conditional, branch_computation->root_instruction(), index,
1709 buffer_assignment));
1710 }
1711 return Status::OK();
1712 }));
1713 for (int j = 0; j < conditional->branch_count(); ++j) {
1714 TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
1715 conditional->operand(j + 1)->shape(),
1716 [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
1717 return CheckHloBuffersShareAllocation(
1718 conditional->operand(j + 1),
1719 conditional->branch_computation(j)->parameter_instruction(0),
1720 index, buffer_assignment);
1721 }));
1722 }
1723 return Status::OK();
1724 }
1725
1726 } // namespace
1727
1728 std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk(
1729 const HloInstruction* hlo) {
1730 // Check that all while-related buffers share an allocation.
1731 TF_CHECK_OK(CheckWhileBuffersShareAllocation(
1732 hlo, ir_emitter_context_->buffer_assignment()));
1733
1734 // Generate thunk sequence for while 'condition'.
1735 HloComputation* condition = hlo->while_condition();
1736 IrEmitterUnnested ir_emitter_condition(hlo_module_config_, condition,
1737 ir_emitter_context_);
1738 TF_CHECK_OK(condition->Accept(&ir_emitter_condition));
1739
1740 // Generate thunk sequence for while 'body'.
1741 HloComputation* body = hlo->while_body();
1742 IrEmitterUnnested ir_emitter_body(hlo_module_config_, body,
1743 ir_emitter_context_);
1744 TF_CHECK_OK(body->Accept(&ir_emitter_body));
1745
1746 return absl::make_unique<WhileThunk>(
1747 GetAllocationSlice(*condition->root_instruction()), // cond result
1748 ir_emitter_condition.ConsumeThunkSequence(),
1749 ir_emitter_body.ConsumeThunkSequence(), hlo);
1750 }
1751
1752 std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
1753 const HloInstruction* hlo, const int64 loop_limit) {
1754 // Check that all while-related buffers share an allocation.
1755 TF_CHECK_OK(CheckWhileBuffersShareAllocation(
1756 hlo, ir_emitter_context_->buffer_assignment()));
1757
1758 // Generate thunk sequence for while 'body' (will be used as the For loop body).
1759 HloComputation* body = hlo->while_body();
1760 IrEmitterUnnested ir_emitter_body(hlo_module_config_, body,
1761 ir_emitter_context_);
1762 TF_CHECK_OK(body->Accept(&ir_emitter_body));
1763
1764 return absl::make_unique<ForThunk>(
1765 loop_limit, ir_emitter_body.ConsumeThunkSequence(), hlo);
1766 }
1767
1768 std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
1769 const HloInstruction* hlo) {
1770 // Check that the buffers used in conditional are shared with the operands and
1771 // result appropriately.
1772 TF_CHECK_OK(CheckConditionalBuffersShareAllocation(
1773 hlo, ir_emitter_context_->buffer_assignment()));
1774
1775 std::vector<BufferAllocation::Slice> branch_operands;
1776 std::vector<ThunkSequence> branch_thunks;
1777 for (int j = 0; j < hlo->branch_count(); ++j) {
1778 branch_operands.emplace_back(GetAllocationSlice(*hlo->operand(j + 1)));
1779 HloComputation* branch_computation = hlo->branch_computation(j);
1780 IrEmitterUnnested ir_emitter(hlo_module_config_, branch_computation,
1781 ir_emitter_context_);
1782 TF_CHECK_OK(branch_computation->Accept(&ir_emitter));
1783 branch_thunks.push_back(std::move(*ir_emitter.ConsumeThunkSequence()));
1784 }
1785
1786 return absl::make_unique<ConditionalThunk>(
1787 GetAllocationSlice(*hlo->operand(0)), branch_operands,
1788 std::move(branch_thunks), hlo);
1789 }
1790
1791 Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
1792 const HloInstruction& hlo,
1793 const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk) {
1794 int unroll_factor = thunk->unroll_factor();
1795 VLOG(3) << bindings_.ToString();
1796
1797 bool multi_output = hlo.shape().IsTuple();
1798
1799 const Shape& element_shape =
1800 multi_output ? ShapeUtil::GetSubshape(hlo.shape(), {0}) : hlo.shape();
1801 VLOG(3) << "EmitTargetElementLoopInThunk "
1802 << ShapeUtil::HumanStringWithLayout(hlo.shape())
1803 << " for unroll_factor " << unroll_factor;
1804 LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
1805 element_shape, ir_emitter_context_->device_description(), unroll_factor);
1806 UpdateLaunchDimensions(launch_dimensions, thunk,
1807 ir_emitter_context_->llvm_module());
1808 if (!multi_output) {
1809 return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo),
1810 launch_dimensions, &b_, unroll_factor)
1811 .EmitLoop(
1812 IrName(&hlo),
1813 GetIndexTypeForKernel(&hlo, launch_dimensions.launch_bound(), &b_));
1814 }
1815
1816 // Emit the tuple pointers in one thread. We could do this at any point in
1817 // the kernel, but we do it at the beginning in the hopes of reducing register
1818 // pressure, since we touch threadIdx.x and blockIdx.x at the beginning of the
1819 // kernel *anyway*.
1820 std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(hlo);
1821 KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] {
1822 llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_);
1823 });
1824
1825 // For multioutput fusion, we need to emit each operand and the root.
1826 TF_RETURN_IF_ERROR(
1827 ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions,
1828 &b_, unroll_factor)
1829 .EmitLoop(IrName(&hlo),
1830 GetIndexTypeForKernel(
1831 &hlo, launch_dimensions.launch_bound(), &b_)));
1832
1833 b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
1834 return Status::OK();
1835 }
1836
1837 namespace {
1838
1839 // Returns true if the fusion contains any instruction that is likely to be
1840 // translated to complex LLVM IR, such as loops, and may prevent vectorization.
1841 bool MayPreventVectorization(const HloInstruction& hlo) {
1842 if (hlo.opcode() == HloOpcode::kFusion) {
1843 return absl::c_any_of(hlo.fused_instructions_computation()->instructions(),
1844 [](const HloInstruction* instr) {
1845 switch (instr->opcode()) {
1846 case HloOpcode::kReduce:
1847 case HloOpcode::kReduceWindow:
1848 case HloOpcode::kSort:
1849 case HloOpcode::kDot:
1850 case HloOpcode::kSin:
1851 case HloOpcode::kCos:
1852 case HloOpcode::kPower:
1853 case HloOpcode::kAtan2:
1854 return true;
1855 default:
1856 return false;
1857 }
1858 });
1859 } else if (hlo.IsElementwise()) {
1860 // Unfused elementwise operations are usually memory bound, unroll them.
1861 switch (hlo.opcode()) {
1862 // The following elementwise operation implementations contain branches.
1863 // LLVM vectorizer doesn't work in that case.
1864 // The unrolled code is faster when it isn't vectorized.
1865 case HloOpcode::kSin:
1866 case HloOpcode::kCos:
1867 case HloOpcode::kPower:
1868 case HloOpcode::kAtan2:
1869 return true;
1870 default:
1871 return false;
1872 }
1873 }
1874 return true;
1875 }
1876
1877 } // namespace
1878
1879 Status IrEmitterUnnested::EmitTargetElementLoop(
1880 const HloInstruction& hlo,
1881 const llvm_ir::ElementGenerator& element_generator) {
1882 int unroll_factor = 1;
1883 if (!MayPreventVectorization(hlo)) {
1884 unroll_factor = ComputeMaxUnrollFactor(&hlo);
1885 }
1886
1887 std::unique_ptr<KernelThunk> kernel_thunk = BuildKernelThunk(
1888 &hlo, /*implements_whole_instruction=*/true, unroll_factor);
1889 Status emit_status =
1890 EmitTargetElementLoopInThunk(hlo, element_generator, kernel_thunk.get());
1891 thunk_sequence_->emplace_back(std::move(kernel_thunk));
1892
1893 return emit_status;
1894 }
1895
1896 // Gets the output offset as calculated from thread_id.x (to be applied to the
1897 // offset calculated from block_id and thread_id.y).
1898 static llvm::Value* GetStartOffsetX(const KernelMappingScheme& mapping_scheme,
1899 llvm::Value* thread_id_x,
1900 llvm::Type* index_ty,
1901 llvm::IRBuilder<>* b) {
1902 if (mapping_scheme.DilatedX()) {
1903 return thread_id_x;
1904 }
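  // In the non-dilated case each thread covers a contiguous run of
  // tile_size_x / num_threads_x elements; e.g. with a tile size of 64 and 32
  // threads, thread t starts at element 2 * t within the tile.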
1905 int64 x_num_steps =
1906 mapping_scheme.GetTileSizeX() / mapping_scheme.GetNumThreadsX();
1907 return b->CreateMul(thread_id_x,
1908 llvm::ConstantInt::get(index_ty, x_num_steps));
1909 }
1910
1911 // Emits code to process up to
1912 // (tile_size_x/num_threads_x * tile_size_y/num_threads_y) elements in a tile,
1913 // given `emit_elem_function` is the function to emit code to process one
1914 // element, `y` and `x` are the intra-tile coordinates for the first element
1915 // to process, and `index` is the index for the origin of the tile. Information
1916 // about tile_size_x/y and num_threads_x/y is stored in `mapping_scheme`. Emits a
1917 // bounds check to ensure that each processed element is within the boundary
1918 // defined by `tile_width` and `tile_height`.
1919 //
1920 // Pseudocode:
1921 //
1922 // for (y_loc = 0; y_loc < tile_height; y_loc += num_threads_y) {
1923 // for (j = 0; j < tile_size_x / num_threads_x; j++) { // unrolled
1924 // if (dilated) {
1925 // x_loc = x + j * num_threads_x;
1926 // } else {
1927 // x_loc = x * (tile_size_x / num_threads_x) + j;
1928 // }
1929 //
1930 // if (x_loc < tile_width) {
1931 // emit_elem_function(y + y_loc, x_loc);
1932 // }
1933 // }
1934 // }
1935 //
1936 static void EmitTile(
1937 const KernelMappingScheme& mapping_scheme,
1938 const IrArray::Index& tile_origin_index, const string& loop_name,
1939 KernelSupportLibrary* ksl, llvm::IRBuilder<>* b, llvm::Value* y,
1940 llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width,
1941 const IrEmitterUnnested::EmitElementFunction& emit_elem_function) {
1942 llvm::Type* index_ty = tile_width->getType();
1943 auto constant = [&](int64 val) {
1944 return llvm::ConstantInt::get(index_ty, val);
1945 };
1946 int64 num_threads_x = mapping_scheme.GetNumThreadsX();
1947 int64 num_threads_y = mapping_scheme.GetNumThreadsY();
1948 int64 tile_size_x = mapping_scheme.GetTileSizeX();
1949
1950 int64 x_num_steps = tile_size_x / num_threads_x;
1951 llvm::Value* start_offset_x = GetStartOffsetX(mapping_scheme, x, index_ty, b);
1952
1953   // With the dilated mapping scheme, each thread steps with a stride equal to
1954   // the number of threads.
1955   // Otherwise, the stride is one, but each thread's start offset is multiplied
1956   // by the number of steps it makes (see GetStartOffsetX above).
1957 int64 step_x = mapping_scheme.DilatedX() ? num_threads_x : 1;
1958
1959 IrArray::Index source_idx =
1960 tile_origin_index.AddOffsetToDim(start_offset_x, kDimX, b);
1961
1962 ksl->For(
1963 loop_name + "_y_in_tile",
1964 /*start=*/y,
1965 /*end=*/tile_height,
1966 /*step=*/constant(num_threads_y), [&](llvm::Value* y_loc) {
1967 IrArray::Index source_idx_y =
1968 source_idx.AddOffsetToDim(y_loc, kDimY, b);
1969 for (int64 j = 0; j < x_num_steps; j++) {
1970 llvm::Value* x_loc =
1971 b->CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
1972 IrArray::Index source_idx_x =
1973 source_idx_y.AddOffsetToDim(constant(j * step_x), kDimX, b);
1974 // The if-statement below always evaluates to true for the blocks
1975 // where the entire processed tile fits within the input buffer.
1976 ksl->If(loop_name + "_x_in_tile", b->CreateICmpULT(x_loc, tile_width),
1977 [&] { emit_elem_function(source_idx_x, y_loc, x_loc, j); });
1978 }
1979 });
1980 }
1981
1982 // Emits code to process a tensor element in a tile for the given kCopy HLO that
1983 // performs a 0-2-1 transpose.
1984 //
1985 // index: The index for the first output element in the normalized tensor. The
1986 // normalized tensor is the resulting tensor after collapsing contiguous
1987 // dimensions that play the same role in the transpose.
1988 // y_loc: The y coordinate within a tile.
1989 // x_loc: The x coordinate within a tile.
1990 // mapping_scheme: Kernel mapping scheme specifying the tiling
1991 void IrEmitterUnnested::EmitTileElementForCopy(
1992 HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
1993 const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
1994 llvm::Value* x_loc, int64 /*x_iter_num*/,
1995 absl::Span<llvm::Value* const> param_shmem_buffers) {
1996 // TODO(jlebar): Add AA metadata to this load.
1997 llvm::Instruction* load_from_shmem_buffer =
1998 Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x_loc, y_loc}),
1999 "output_element");
2000 llvm_ir::IrArray output_array = GetIrArray(*hlo, *hlo);
2001 Shape output_reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
2002 hlo->shape().element_type(), mapping_scheme.GetDimsInElems());
2003 // When the output_reduced_shape is a 0-2-1 transpose of the input shape,
2004 // the 0-2-1 transpose is achieved through EmitWriteArrayElement.
2005 output_array.CastToShape(output_reduced_shape, &b_)
2006 .EmitWriteArrayElement(index, load_from_shmem_buffer, &b_);
2007 }
2008
2009 static IrArray::Index GetUnnormalizedIndex(
2010 const IrArray::Index& normalized_shape_index,
2011 const Shape& unnormalized_shape, llvm::IRBuilder<>* b_,
2012 const KernelMappingScheme& kernel_mapping_scheme) {
2013 DCHECK_EQ(normalized_shape_index.size(), 3);
2014 llvm::Value* linear = normalized_shape_index.Linearize(
2015 kernel_mapping_scheme.GetDimsInElems(), b_);
2016 return IrArray::Index(linear, unnormalized_shape, b_);
2017 }
2018
2019 // Emits code to process a tensor element in a tile for the given kLoop fusion
2020 // HLO containing parameters that are a 0-2-1 transpose of its outputs.
2021 //
2022 // index: The index for the first output element in the normalized tensor, that
2023 // is the resulting tensor after collapsing contiguous dimensions that play
2024 // the same role in the transpose.
2025 // kernel_info: Other information to support the kernel code generation.
2026 // y_loc: The y coordinate within a tile.
2027 // x_loc: The x coordinate within a tile.
2028 void IrEmitterUnnested::EmitTileElementForFusion(
2029 HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
2030 const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
2031 llvm::Value* x_loc, int64 /*x_iter_num*/,
2032 absl::Span<llvm::Value* const> param_shmem_buffers) {
2033 std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo);
2034 GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
2035 GetNestedComputer());
2036 FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo),
2037 &elem_emitter, x_loc, y_loc,
2038 param_shmem_buffers);
2039
2040 TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter));
2041 IrArray::Index untiled_index = GetUnnormalizedIndex(
2042 index, output_arrays[0].GetShape(), &b_, mapping_scheme);
2043 const llvm_ir::ElementGenerator& output_generator =
2044 fused_emitter.GetRootGenerator();
2045 llvm::Value* output_value = output_generator(untiled_index).ValueOrDie();
2046 if (hlo->IsMultiOutputFusion()) {
2047 DCHECK(output_value->getType()->isStructTy());
2048 DCHECK_EQ(output_value->getType()->getStructNumElements(),
2049 output_arrays.size());
2050 for (int64 i = 0; i < output_arrays.size(); ++i) {
2051 output_arrays[i].EmitWriteArrayElement(
2052 untiled_index, ExtractValue(output_value, i), &b_);
2053 }
2054 } else {
2055 output_arrays[0].EmitWriteArrayElement(untiled_index, output_value, &b_);
2056 }
2057 }
2058
2059 // Gets the number of partial results accumulated by a single thread performing
2060 // reduction.
2061 static int GetNumberOfPartialResults(
2062 const ReductionCodegenInfo& reduction_info) {
2063 const KernelMappingScheme& mapping_scheme =
2064 reduction_info.GetKernelMappingScheme();
2065 if (reduction_info.IsRowReduction()) {
2066 return 1;
2067 }
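  // For column reductions with a non-dilated x mapping, each thread steps over
  // tile_size_x / num_threads_x (i.e. 2) consecutive elements along x, so it
  // accumulates two partial results.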
2068 int64 num_partial_results = mapping_scheme.DilatedX() ? 1 : 2;
2069 CHECK_EQ(num_partial_results,
2070 (mapping_scheme.GetTileSizeX() / mapping_scheme.GetNumThreadsX()));
2071 return num_partial_results;
2072 }
2073
2074 void IrEmitterUnnested::EmitPrologueForReduction(
2075 HloInstruction* unnested_hlo, ReductionCodegenInfo* reduction_info,
2076 absl::Span<HloInstruction* const> reduce_instructions,
2077 llvm::Type* index_type) {
2078 VLOG(10) << "Emit prologue for reduction: " << unnested_hlo->ToString();
2079 GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
2080 ir_emitter_context_->llvm_module(),
2081 &b_, GetNestedComputer());
2082 const HloInstruction* first_reduce = nullptr;
2083 for (int i = 0; i < reduce_instructions.size(); i++) {
2084 HloInstruction* reduce_inst = reduce_instructions[i];
2085 VLOG(10) << "Emit prologue for reduction: " << reduce_inst->ToString();
2086 if (first_reduce == nullptr) {
2087 first_reduce = reduce_inst;
2088 } else {
2089 CHECK(first_reduce->dimensions() == reduce_inst->dimensions());
2090 }
2091
2092 AddressVector* reduction_input_addresses =
2093 reduction_info->GetMutableReductionInputAddresses();
2094 llvm::Type* element_type =
2095 llvm_ir::PrimitiveTypeToIrType(reduce_inst->shape().element_type(),
2096 ir_emitter_context_->llvm_module());
2097 llvm::AllocaInst* reduction_input_address = Alloca(element_type);
2098 reduction_input_addresses->push_back(reduction_input_address);
2099
2100 int num_partial_results = GetNumberOfPartialResults(*reduction_info);
2101 AddressVector* partial_result_addresses =
2102 reduction_info->GetMutablePartialResultAddresses();
2103 llvm::AllocaInst* partial_result_address =
2104 Alloca(element_type, /*ArraySize=*/b_.getInt32(num_partial_results),
2105 "partial_reduction_result." + llvm::Twine(i));
2106 partial_result_addresses->push_back(partial_result_address);
2107
2108 // Initialize the partial result with the initial value of the reduction.
2109 llvm::Value* init_ir_value;
2110 const HloInstruction* init_value = reduce_inst->operand(1);
2111 if (unnested_hlo->opcode() == HloOpcode::kFusion) {
2112 FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo),
2113 &elemental_emitter);
2114
2115 TF_CHECK_OK(init_value->Accept(&fused_emitter));
2116 init_ir_value =
2117 fused_emitter
2118 .GetGenerator(init_value)(IrArray::Index(b_.getInt32Ty()))
2119 .ValueOrDie();
2120 } else {
2121 init_ir_value =
2122 GetIrArray(*init_value, *unnested_hlo)
2123 .EmitReadArrayElement(IrArray::Index(b_.getInt32Ty()), &b_);
2124 }
2125
2126 for (int i = 0; i < num_partial_results; ++i) {
2127 Store(init_ir_value,
2128 InBoundsGEP(partial_result_address, {b_.getInt32(i)}));
2129 }
2130 }
2131 }
2132
2133 void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForAllReduces(
2134 absl::Span<HloComputation* const> reducers,
2135 absl::Span<llvm::AllocaInst* const> partial_result_addresses) {
2136 CHECK_EQ(reducers.size(), partial_result_addresses.size());
2137 for (int i = 0; i != reducers.size(); i++) {
2138 EmitFullWarpShuffleDownLoopForReduce(
2139 reducers[i], partial_result_addresses[i]->getType()->getElementType(),
2140 partial_result_addresses[i]);
2141 }
2142 }
2143
2144 void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForReduce(
2145 HloComputation* reducer, llvm::Type* element_type,
2146 llvm::Value* partial_result_address) {
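  // Butterfly reduction across a warp: each iteration combines this lane's
  // partial result with the value shuffled down from the lane `distance`
  // positions higher, halving `distance` from 16 down to 1.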
2147 for (int distance = 16; distance >= 1; distance /= 2) {
2148 int bit_width = llvm_ir::GetSizeInBits(element_type);
2149 llvm::Value* result_from_other_lane =
2150 Alloca(element_type, nullptr, "result_from_other_lane");
2151 // Bitcast cannot be applied to aggregate types (even packed ones), so
2152 // we bitcast addresses of load/store to intN* of the same bit-width.
2153 llvm::Type* shuffled_value_type =
2154 element_type->isStructTy() ? b_.getIntNTy(bit_width) : element_type;
2155 auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) {
2156 return b_.CreatePointerBitCastOrAddrSpaceCast(
2157 ptr, shuffled_value_type->getPointerTo());
2158 };
2159 llvm::Value* partial_result =
2160 Load(convert_pointer_for_shuffle(partial_result_address),
2161 "partial_reduction_result");
2162 Store(EmitFullWarpShuffleDown(partial_result, b_.getInt32(distance), &b_),
2163 convert_pointer_for_shuffle(result_from_other_lane));
2164 TF_CHECK_OK(EmitCallToNestedComputation(
2165 *reducer, {partial_result_address, result_from_other_lane},
2166 partial_result_address));
2167 }
2168 }
2169
2170 // Given the IrArray index of a reduction input, returns the linear address of
2171 // the reduction output as if the reduction were going to keep the input shape
2172 // with the dimensions being reduced moved.
2173 static llvm::Value* GetUntransposedOutputLinearAddress(
2174 llvm::IRBuilder<>* b, const llvm_ir::IrArray::Index& index,
2175 const ReductionCodegenInfo& reduction_info) {
2176 const KernelMappingScheme& kernel_mapping_scheme =
2177 reduction_info.GetKernelMappingScheme();
2178 if (reduction_info.IsRowReduction()) {
2179 return index[kDimY];
2180 }
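  // For column reductions the untransposed output is addressed by (z, x):
  // linear address = z * dims_in_elem[kDimX] + x.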
2181 absl::Span<const int64> dims_in_elem = kernel_mapping_scheme.GetDimsInElems();
2182 llvm::Value* x_dim_size = index.GetConstantWithIndexType(dims_in_elem[kDimX]);
2183 llvm::Value* x_block_offset = b->CreateMul(index[kDimZ], x_dim_size);
2184 return b->CreateAdd(x_block_offset, index[kDimX]);
2185 }
2186
2187 void IrEmitterUnnested::EmitEpilogueForReduction(
2188 llvm::Type* index_ty, HloInstruction* unnested_hlo,
2189 const ReductionCodegenInfo& reduction_info,
2190 absl::Span<const HloInstruction* const> reduce_instructions,
2191 absl::Span<const ShapeIndex> reduction_output_shape_indices,
2192 absl::Span<HloComputation* const> reducers,
2193 const TilingKernelInfo& tiling_kernel_info) {
2194 const KernelMappingScheme& mapping_scheme =
2195 reduction_info.GetKernelMappingScheme();
2196 auto constant = [&](uint64 c) -> llvm::Constant* {
2197 return llvm::ConstantInt::get(index_ty, c);
2198 };
2199
2200 IrEmitterUnnested::ThreadIdInfo thread_id_info =
2201 EmitThreadIdInfo(mapping_scheme.GetThreadsPerBlock(), index_ty,
2202 mapping_scheme.GetNumThreadsX());
2203 llvm::Value* start_offset_x = GetStartOffsetX(
2204 mapping_scheme, thread_id_info.thread_id_x, index_ty, &b_);
2205
2206 IrArray::Index start_offset =
2207 tiling_kernel_info.tile_origin
2208 .AddOffsetToDim(thread_id_info.thread_id_y, kDimY, &b_)
2209 .AddOffsetToDim(start_offset_x, kDimX, &b_);
2210
2211 int num_reduces = reducers.size();
2212 absl::Span<llvm::AllocaInst* const> partial_result_addresses =
2213 reduction_info.GetPartialResultAddresses();
2214 if (reduction_info.IsRowReduction()) {
2215 EmitFullWarpShuffleDownLoopForAllReduces(reducers,
2216 partial_result_addresses);
2217 llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
2218 ICmpEQ(thread_id_info.lane_id, constant(0)), "lane_id_is_zero", &b_);
2219 llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_);
2220 } else {
2221 // Some threads in the block are completely outside of the bound of the
2222 // tensor, so they should not do anything at all.
2223 llvm::Value* has_output = b_.CreateAnd(
2224 b_.CreateICmpULT(start_offset_x,
2225 tiling_kernel_info.output_tile_bounds[kDimX]),
2226 b_.CreateICmpULT(thread_id_info.thread_id_y,
2227 tiling_kernel_info.output_tile_bounds[kDimY]));
2228 llvm_ir::LlvmIfData if_has_output =
2229 llvm_ir::EmitIfThenElse(has_output, "output_inbound", &b_);
2230 llvm_ir::SetToFirstInsertPoint(if_has_output.true_block, &b_);
2231 }
2232
2233 int num_partial_results = GetNumberOfPartialResults(reduction_info);
2234
2235 // Emit an atomic operation that accumulates the partial reduction to the
2236 // output element. For row reduction, this is only for lane 0 due to the
2237 // if-statement emitted above.
2238 for (int i = 0; i != num_reduces; ++i) {
2239 const HloInstruction* reduce_hlo = reduce_instructions[i];
2240 Shape reduction_kept_element_shape = ShapeUtil::FilterDimensions(
2241 [&](int64 dim) {
2242 return !absl::c_linear_search(reduce_hlo->dimensions(), dim);
2243 },
2244 reduce_hlo->operand(0)->shape());
2245 for (int j = 0; j < num_partial_results; ++j) {
2246 llvm::Value* untransposed_output_linear_address =
2247 GetUntransposedOutputLinearAddress(
2248 &b_, start_offset.AddOffsetToDim(constant(j), kDimX, &b_),
2249 reduction_info);
2250
2251 // A reduction is allowed to transpose its output. For example, suppose
2252 // we are reducing the second dimension of f32[10,20,30]{3,2,1}. We are
2253 // allowed to produce as output either f32[10,30]{1,0} (no transpose) or
2254 // f32[10,30]{0,1} (transposing the two output dims).
2255 //
2256 // At this point in the function we have a "partial sum" of input elements
2257 // (stored in partial_result_addresses), and we need to accumulate it into
2258 // the correct output element.
2259 auto output_array = GetIrArray(*unnested_hlo, *unnested_hlo,
2260 reduction_output_shape_indices[i]);
2261 IrArray::Index element_index(
2262 /*linear=*/untransposed_output_linear_address,
2263 reduction_kept_element_shape, &b_);
2264 IrArray::Index output_index(element_index.multidim(),
2265 output_array.GetShape(),
2266 element_index.GetType());
2267 llvm::Value* output_address = output_array.EmitArrayElementAddress(
2268 output_index, &b_, "output_element_address");
2269 TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
2270 *reducers[i], output_address,
2271 InBoundsGEP(partial_result_addresses[i], {constant(j)})));
2272 }
2273 }
2274 }
2275
2276 llvm::Value* IrEmitterUnnested::EmitBlockId() {
2277 return gpu::EmitCallToTargetIntrinsic(gpu::TargetIntrinsicID::kBlockIdx, {},
2278 {}, &b_);
2279 }
2280
2281 void IrEmitterUnnested::EmitPrintfWithThreadId(
2282 absl::string_view fmt, absl::Span<llvm::Value* const> arguments,
2283 absl::optional<int64> thread_id_filter,
2284 absl::optional<int64> block_id_filter) {
2285 llvm::Value* thread_id = EmitThreadId(1024, b_.getInt32Ty());
2286 llvm::Value* block_id = EmitBlockId();
2287 std::vector<llvm::Value*> updated_arguments = {thread_id, block_id};
2288 updated_arguments.insert(updated_arguments.end(), arguments.begin(),
2289 arguments.end());
2290 llvm::Value* constraint = b_.getTrue();
2291 if (thread_id_filter) {
2292 constraint = b_.CreateAnd(
2293 constraint, b_.CreateICmpEQ(thread_id, b_.getInt32(*thread_id_filter)));
2294 }
2295 if (block_id_filter) {
2296 constraint = b_.CreateAnd(
2297 constraint, b_.CreateICmpEQ(block_id, b_.getInt32(*block_id_filter)));
2298 }
2299 KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
2300 ksl.If(constraint, [&] {
2301 ::xla::gpu::EmitPrintf(absl::StrCat("[TID=%d,BID=%d] ", fmt, "\n"),
2302 updated_arguments, &b_);
2303 });
2304 }
2305
2306 void IrEmitterUnnested::EmitTileElementForReduction(
2307 HloInstruction* unnested_hlo, const Shape& reduction_operand_shape,
2308 absl::Span<HloInstruction* const> output_instructions,
2309 const llvm_ir::IrArray::Index& index,
2310 const ReductionCodegenInfo& reduction_info,
2311 absl::Span<HloComputation* const> reducers, int64 x_iter_num) {
2312 VLOG(10) << "Emit tile element for reduce " << unnested_hlo->ToString();
2313 bool returns_tuple = output_instructions.size() > 1;
2314 int partial_result_index = reduction_info.IsRowReduction() ? 0 : x_iter_num;
2315
2316 InlinedVector<llvm_ir::ElementGenerator, 1> input_gens;
2317 std::vector<std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
2318 extra_output_gens;
2319 GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
2320 GetNestedComputer());
2321 FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo),
2322 &elem_emitter);
2323   // Construct the ElementGenerator for each reduction and extra output in
2324   // the group of output instructions.
2325 if (unnested_hlo->opcode() == HloOpcode::kFusion) {
2326 TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter));
2327
2328 for (int i = 0, e = output_instructions.size(); i != e; ++i) {
2329 const HloInstruction* inst = output_instructions[i];
2330 ShapeIndex idx = returns_tuple ? ShapeIndex({i}) : ShapeIndex({});
2331 if (IsReductionFromOrToContiguousDimensions(*inst)) {
2332 input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0)));
2333 } else {
2334 extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst),
2335 std::move(idx));
2336 }
2337 }
2338 } else {
2339 input_gens.push_back([&](const IrArray::Index& index) {
2340 return GetIrArray(*unnested_hlo->operand(0), *unnested_hlo)
2341 .EmitReadArrayElement(index, &b_);
2342 });
2343 }
2344
2345 IrArray::Index input_index =
2346 GetUnnormalizedIndex(index, reduction_operand_shape, &b_,
2347 reduction_info.GetKernelMappingScheme());
2348 // Clear the linear index field of the IrArray::Index to enable the use of
2349 // GetElementPointer with array types. This enables the vectorization of
2350 // the computation for different partial results. Use this index if
2351 // 'num_partial_results > 1'.
2352 int num_partial_results = GetNumberOfPartialResults(reduction_info);
2353 auto index_without_linear = IrArray::Index(
2354 input_index.multidim(), reduction_operand_shape, input_index.GetType());
2355
2356 // Emit code to generate the input and perform the reduction computation for
2357 // each reduction instruction.
2358 for (int i = 0; i != reducers.size(); ++i) {
2359 llvm::AllocaInst* input_address =
2360 reduction_info.GetReductionInputAddresses()[i];
2361 llvm::AllocaInst* partial_reduction_result_address =
2362 reduction_info.GetPartialResultAddresses()[i];
2363 llvm::Value* const input_ir_value =
2364 input_gens[i](num_partial_results > 1 ? index_without_linear
2365 : input_index)
2366 .ValueOrDie();
2367 Store(input_ir_value, input_address);
2368 llvm::Value* partial_result_address = InBoundsGEP(
2369 partial_reduction_result_address, {b_.getInt32(partial_result_index)});
2370 TF_CHECK_OK(EmitCallToNestedComputation(
2371 *reducers[i], {partial_result_address, input_address},
2372 partial_result_address));
2373 }
2374
2375 // Emit code to generate the output for the non-reduction instructions in the
2376 // fusion, if any.
2377 TF_CHECK_OK(EmitExtraOutputsForReduce(
2378 unnested_hlo, input_index,
2379 /*use_linear_index=*/num_partial_results == 1, extra_output_gens));
2380 }
2381
2382 llvm::Value* IrEmitterUnnested::EmitThreadId(int64 threads_per_block,
2383 llvm::Type* index_ty) {
2384 // Calculate (y, x) coordinates respectively in the 2D view of thread block,
2385 // defined by (num_thread_y, num_thread_x) from thread_id.
2386 llvm::CallInst* thread_id_raw = gpu::EmitCallToTargetIntrinsic(
2387 gpu::TargetIntrinsicID::kThreadIdx, {}, {}, &b_);
2388 llvm_ir::AddRangeMetadata(0, threads_per_block, thread_id_raw);
2389 return b_.CreateIntCast(thread_id_raw, index_ty,
2390 /*isSigned=*/true, "thread.id.x");
2391 }
2392
2393 IrEmitterUnnested::ThreadIdInfo IrEmitterUnnested::EmitThreadIdInfo(
2394 int64 threads_per_block, llvm::Type* index_ty, int64 num_threads_x) {
2395 auto constant = [&](uint64 c) -> llvm::Constant* {
2396 return llvm::ConstantInt::get(index_ty, c);
2397 };
2398 llvm::Value* thread_id = EmitThreadId(threads_per_block, index_ty);
2399 llvm::Value* num_threads_x_v = constant(num_threads_x);
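  // The linear thread id decomposes as
  //   thread_id = thread_id_y * num_threads_x + thread_id_x;
  // e.g. with num_threads_x = 32 (and kWarpSize = 32), thread 70 gets
  // (thread_id_y, thread_id_x) = (2, 6) and lane_id = 6.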
2400 return {
2401 /*thread_id=*/thread_id,
2402 /*thread_id_x=*/b_.CreateURem(thread_id, num_threads_x_v, "thread_id.x"),
2403 /*thread_id_y=*/b_.CreateUDiv(thread_id, num_threads_x_v, "thread_id.y"),
2404 /*lane_id=*/b_.CreateURem(thread_id, constant(kWarpSize), "lane_id")};
2405 }
2406
2407 IrEmitterUnnested::TilingKernelInfo IrEmitterUnnested::EmitTilingKernel(
2408 const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty,
2409 const TileElementGenerator& tile_element_generator) {
2410 absl::Span<const int64> dims_in_elems = mapping_scheme.GetDimsInElems();
2411 std::vector<int64> dims_in_blocks = {
2412 CeilOfRatio(dims_in_elems[0], mapping_scheme.GetTileSizeZ()),
2413 CeilOfRatio(dims_in_elems[1], mapping_scheme.GetTileSizeY()),
2414 CeilOfRatio(dims_in_elems[2], mapping_scheme.GetTileSizeX())};
2415 auto constant = [&](uint64 c) -> llvm::Constant* {
2416 return llvm::ConstantInt::get(index_ty, c);
2417 };
2418
2419 IrEmitterUnnested::ThreadIdInfo thread_id_info =
2420 EmitThreadIdInfo(mapping_scheme.GetThreadsPerBlock(), index_ty,
2421 mapping_scheme.GetNumThreadsX());
2422
2423 KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
2424
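  // View the launch grid as a 3D array of blocks shaped like `dims_in_blocks`
  // and decompose the linear block id into (z, y, x) block coordinates; the z
  // coordinate is then scaled by the tile size to get the block origin in z.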
2425 const IrArray::Index block_coords = [&] {
2426 llvm::Value* block_id = EmitBlockId();
2427 llvm_ir::AddRangeMetadata(0, mapping_scheme.GetNumberOfBlocks(),
2428 llvm::cast<llvm::Instruction>(block_id));
2429 llvm::Value* linear_block_id =
2430 b_.CreateIntCast(block_id, index_ty, /*isSigned=*/true, "block.id.x");
2431 IrArray::Index starting_block(linear_block_id,
2432 ShapeUtil::MakeShapeWithDescendingLayout(
2433 PRED /*arbitrary*/, dims_in_blocks),
2434 &b_);
2435
2436 std::vector<llvm::Value*> multidim = {
2437 b_.CreateMul(starting_block[0], constant(mapping_scheme.GetTileSizeZ()),
2438 "block_origin.z"),
2439 starting_block[1], starting_block[2]};
2440 return IrArray::Index(multidim, dims_in_blocks, index_ty);
2441 }();
2442
2443 std::array<llvm::Value*, 3> output_tile_bounds;
2444 for (int i = kDimY; i < kDimTot; ++i) {
2445 int64 tile_size_for_dim = mapping_scheme.GetTileSizeFor(i);
2446 // Only last row or column may not have full size.
2447 llvm::Value* is_last =
2448 b_.CreateICmpEQ(block_coords[i], constant(dims_in_blocks[i] - 1));
2449 int64 partial_row =
2450 dims_in_elems[i] - (dims_in_blocks[i] - 1) * tile_size_for_dim;
2451 output_tile_bounds[i] =
2452 b_.CreateSelect(is_last, constant(partial_row),
2453 constant(tile_size_for_dim), "tile_bound");
2454 }
2455
2456 IrArray::Index tile_origin = [&] {
2457 std::vector<llvm::Value*> elem_multi_index = block_coords.multidim();
2458 llvm::Type* index_ty = block_coords.GetType();
2459 for (int i = kDimY; i < kDimTot; ++i) {
2460 elem_multi_index[i] = b_.CreateMul(
2461 block_coords[i],
2462 llvm::ConstantInt::get(index_ty, mapping_scheme.GetTileSizeFor(i)),
2463 "tile_origin." + std::to_string(i));
2464 }
2465 return IrArray::Index(elem_multi_index, mapping_scheme.GetDimsInElems(),
2466 index_ty);
2467 }();
2468
2469 auto emit_tile = [&](const IrArray::Index& tile) {
2470 tile_element_generator(thread_id_info.thread_id_y,
2471 thread_id_info.thread_id_x, tile, "output",
2472 output_tile_bounds[1], output_tile_bounds[2], &ksl);
2473 };
2474
2475 if (mapping_scheme.GetTileSizeZ() == 1) {
2476 emit_tile(tile_origin);
2477 } else {
2478 llvm::Value* starting_tile_index_for_dim = tile_origin[kDimZ];
2479 llvm::Value* block_size_for_dim = constant(mapping_scheme.GetTileSizeZ());
2480 llvm::Value* block_id_for_dim =
2481 b_.CreateUDiv(starting_tile_index_for_dim, block_size_for_dim);
2482 llvm::Value* last_block_for_dim = constant(dims_in_blocks[kDimZ] - 1);
2483 llvm::Value* last_block_size_for_dim =
2484 constant(dims_in_elems[kDimZ] -
2485 (dims_in_blocks[kDimZ] - 1) * mapping_scheme.GetTileSizeZ());
2486
2487 llvm::Value* num_tiles_in_block =
2488 b_.CreateSelect(b_.CreateICmpEQ(last_block_for_dim, block_id_for_dim),
2489 last_block_size_for_dim, block_size_for_dim);
2490 ksl.For("loop_z",
2491 /*start=*/constant(0),
2492 /*end=*/num_tiles_in_block,
2493 /*step=*/1, [&](llvm::Value* block_dim_induction_var) {
2494 IrArray::Index tile_index = tile_origin.AddOffsetToDim(
2495 block_dim_induction_var, kDimZ, &b_);
2496 emit_tile(tile_index);
2497 });
2498 }
2499 return {output_tile_bounds, tile_origin};
2500 }
2501
2502 llvm::CallInst* IrEmitterUnnested::EmitSyncThreads() {
2503 return EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_);
2504 }
2505
2506 // Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose
2507 // algorithm to improve the memory access patterns for the input parameters
2508 // with a shape that is a 0-2-1 transpose of the output tensor shape. The caller
2509 // is responsible for making sure that it is safe to apply the shared memory
2510 // transpose on the input parameters.
2511 //
2512 //
2513 // For the purpose of tiling, the output tensors have a logical shape of three
2514 // components 0-2-1 while the relevant input parameters have a logical shape
2515 // of three components 0-1-2 in the order major to minor. The x- and y-
2516 // dimensions of the tensors are tiled in square tiles with an edge length
2517 // `kTileSize`. Each thread block of `kTileSize` x `kNumRows` threads
2518 // transposes one tile: each thread copies kTileSize/kNumRows elements from
2519 // the input to a shared memory tile, then the otherwise "regular HLO kernel"
2520 // reads from the shared memory instead of the original input.
2521 //
2522 // This is similar to the following CUDA algorithm in TensorFlow:
2523 // https://goo.gl/MStRV6.
2524 //
2525 // `kTileSize` should usually be the same as the warp size. We currently choose 32 for
2526 // `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`.
2527 //
2528 // TODO(b/33320379): Here each block transposes 1 tile. It may be more
2529 // efficient to launch fewer blocks so each transposes many tiles.
2530 void IrEmitterUnnested::EmitHlo021Tile(
2531 HloInstruction* hlo, Thunk* kernel_thunk,
2532 absl::Span<const int64> reduced_output_dims,
2533 absl::Span<const int64> tiled_param_ids) {
2534 constexpr int kNumRows = 4;
2535 KernelMappingScheme mapping_scheme(reduced_output_dims,
2536 /*tile_sizes=*/{1, kWarpSize, kWarpSize},
2537 /*num_threads_y=*/kNumRows,
2538 /*num_threads_x=*/kWarpSize,
2539 /*is_dilated_x=*/false);
2540 LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
2541 mapping_scheme.GetThreadsPerBlock());
2542 llvm::Type* index_type =
2543 GetIndexTypeForKernel(hlo, launch_dimensions.launch_bound(), &b_);
2544 std::vector<IrArray> param_arrays;
2545
2546 // For each tiled parameter, cast its input IrArray to the corresponding
2547 // reduced shape and keep the reduced shape live during IR emission.
2548 std::vector<IrArray> param_in_reduced_shape_arrays;
2549 std::vector<llvm::Value*> param_shmem_buffers(hlo->operand_count(), nullptr);
2550
2551 auto get_shared_memory_buffer = [&](llvm::Type* elem_ty,
2552 absl::string_view buffer_name) {
2553 // For Nvidia GPUs, the warp size is 32 threads and shared memory is
2554 // organized into 32 banks. We usually use the warp size, or a multiple of
2555 // the warp size, as the tile size. This can cause all elements in the same
2556 // column of a tile to map to the same memory bank and therefore produce
2557 // shared memory bank conflicts. Adding 1 to the minor dimension of the
2558 // shared memory buffer reduces such bank conflicts.
2559 llvm::Type* buffer_type = llvm::ArrayType::get(
2560 llvm::ArrayType::get(elem_ty, mapping_scheme.GetTileSizeX() + 1),
2561 mapping_scheme.GetTileSizeY());
2562 return llvm_ir::AllocateSharedMemoryTile(b_.GetInsertBlock()->getModule(),
2563 buffer_type, buffer_name);
2564 };
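// For example, with kWarpSize == 32 and f32 elements the buffer above has type
// [32 x [33 x float]] (about 4.1 KiB per tile); the extra column of padding
// shifts consecutive rows by one bank, so reading a column of the tile no
// longer hits a single shared memory bank.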
2565
2566 for (int64 id = 0; id < hlo->operand_count(); id++) {
2567 const HloInstruction* param = hlo->operand(id);
2568 param_arrays.push_back(GetIrArray(*param, *hlo));
2569
2570 if (absl::c_linear_search(tiled_param_ids, id)) {
2571 param_shmem_buffers[id] =
2572 get_shared_memory_buffer(llvm_ir::PrimitiveTypeToIrType(
2573 param->shape().element_type(), module_),
2574 IrName(hlo, StrCat("tile", id)));
2575 VLOG(3) << "Added shmem buffer for parameter " << id << ": "
2576 << llvm_ir::DumpToString(*param_shmem_buffers[id]);
2577 Shape reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
2578 param->shape().element_type(),
2579 Permute({0, 2, 1}, reduced_output_dims));
2580 param_in_reduced_shape_arrays.push_back(
2581 param_arrays[id].CastToShape(reduced_shape, &b_));
2582 } else {
2583 param_in_reduced_shape_arrays.push_back(IrArray());
2584 }
2585 }
2586
2587 EmitElementFunction element_generator =
2588 [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
2589 llvm::Value* x_loc, int64 x_iter_num) {
2590 if (hlo->opcode() == HloOpcode::kCopy) {
2591 EmitTileElementForCopy(hlo, index, mapping_scheme, y_loc, x_loc,
2592 x_iter_num, param_shmem_buffers);
2593 } else {
2594 CHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
2595 EmitTileElementForFusion(hlo, index, mapping_scheme, y_loc, x_loc,
2596 x_iter_num, param_shmem_buffers);
2597 }
2598 };
2599
2600 TileElementGenerator tile_generator =
2601 [&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
2602 const string& loop_name, llvm::Value* tile_height,
2603 llvm::Value* tile_width, KernelSupportLibrary* ksl) {
2604 // If shared memory transpose is needed, wait for all threads to reach
2605 // this point, lest we copy a value from tile to output before another
2606 // thread copies it from input to tile. This is `__syncthreads` in CUDA.
2607 if (!tiled_param_ids.empty()) {
2608 // Calculate the input tile origin from the output tile origin.
2609 const IrArray::Index input_tile_origin(
2610 Permute({0, 2, 1}, index.multidim()),
2611 Permute({0, 2, 1}, index.dims()), index.GetType());
2612
2613 // Copy input parameter values to shared memory buffers:
2614 // tile[y, x] = input[index]
2615 // Note that tile_width and tile_height are flipped here because we
2616 // are reading a transposed tile.
2617 EmitTile(mapping_scheme, input_tile_origin, "input", ksl, &b_, y, x,
2618 tile_width, tile_height,
2619 [&](const IrArray::Index& index, llvm::Value* y_loc,
2620 llvm::Value* x_loc, int64 /*x_iter_num*/) {
2621 for (int64 id : tiled_param_ids) {
2622 IrArray& input_in_logical_shape =
2623 param_in_reduced_shape_arrays[id];
2624
2625 llvm::Value* shmem_buffer = param_shmem_buffers[id];
2626 llvm::Value* zero =
2627 llvm::ConstantInt::get(index_type, 0);
2628 // TODO(jlebar): Add AA metadata to this store. Tile
2629 // buffers are global variables, so LLVM can't infer much
2630 // about them.
2631 Store(input_in_logical_shape.EmitReadArrayElement(
2632 index, &b_, "input_element"),
2633 GEP(shmem_buffer, {zero, y_loc, x_loc}));
2634 }
2635 });
2636
2637 // Wait for all threads to reach this point using `__syncthreads` in
2638 // CUDA.
2639 EmitSyncThreads();
2640 }
2641
2642 EmitTile(mapping_scheme, index, loop_name, ksl, &b_, y, x, tile_height,
2643 tile_width, element_generator);
2644 bool block_contains_multi_tiles = mapping_scheme.GetTileSizeZ() > 1;
2645
2646 // If a tile block contains multiple tiles and shared memory buffers are
2647 // used, we need to wait for all threads to finish using the shared
2648 // memory buffer for the current tile before we move on to process the
2649 // next tile and overwrite the shared memory buffers.
2650 if (block_contains_multi_tiles && !tiled_param_ids.empty()) {
2651 EmitSyncThreads();
2652 }
2653 };
2654
2655 // For multioutput fusion, one thread needs to output a tuple
2656 // with pointers to all the individual outputs. We could do this
2657 // at any point in the kernel, but we do it at the beginning in
2658 // the hopes of reducing register pressure, since we touch
2659 // threadIdx.x and blockIdx.x at the beginning of the kernel
2660 // *anyway*.
2661 if (hlo->IsMultiOutputFusion()) {
2662 KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] {
2663 llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo),
2664 ConstructIrArrayForOutputs(*hlo), &b_);
2665 });
2666 }
2667
2668 EmitTilingKernel(mapping_scheme, index_type, tile_generator);
2669 UpdateLaunchDimensions(launch_dimensions, kernel_thunk,
2670 ir_emitter_context_->llvm_module());
2671 }
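// A minimal CUDA-style sketch of the tiling strategy above (illustrative only:
// the actual kernel is emitted as LLVM IR, interleaves the fused elementwise
// computation, and guards against partial tiles; `in`, `out`, `m` and `n` are
// hypothetical names for an m x n input transposed into an n x m output):
//
//   __shared__ float tile[32][33];            // kTileSize x (kTileSize + 1)
//   int x = blockIdx.x * 32 + threadIdx.x;
//   int y = blockIdx.y * 32 + threadIdx.y;
//   for (int j = 0; j < 32; j += kNumRows)    // each thread copies 32/4 rows
//     tile[threadIdx.y + j][threadIdx.x] = in[(y + j) * n + x];
//   __syncthreads();
//   x = blockIdx.y * 32 + threadIdx.x;        // swap block coordinates
//   y = blockIdx.x * 32 + threadIdx.y;
//   for (int j = 0; j < 32; j += kNumRows)
//     out[(y + j) * m + x] = tile[threadIdx.x][threadIdx.y + j];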
2672
2673 namespace {
2674 // A recursive function that inspects the users of a fusion parameter to
2675 // determine whether it is safe for that parameter to participate in a
2676 // shared-memory transpose.
2677 //
2678 // Consider a fusion parameter P for which we might want to use a shmem
2679 // transpose. If we do, we use a GPU thread block to preload a tile of P with
2680 // indices [z, y..y+31, x..x+31] to compute an output tile with the same indices
2681 // cooperatively, where z, y, x are the indices for the normalized input/output
2682 // tensor (see the document for FindTranspose021 for the definition of
2683 // normalized tensor for 0-2-1 transpose). This shmem transpose implementation
2684 // requires that the computation of the output tile only read elements within
2685 // the preload tile. If this is not true, we can't use a shmem transpose for P.
2686 //
2687 // If the computation of output element [z, y, x] only requires the element of
2688 // P with the same indices, the shmem transpose implementation can be applied
2689 // to P safely. This is a sufficient but not necessary condition. We check all
2690 // the transitive users of P to see whether any of them violates this
2691 // property. If no such user is found, we conclude that P
2692 // is safe for shmem transpose.
2693 //
2694 // This is trivially true for elementwise operations and some "data-movement"
2695 // ops like kTuple. However, it's not true for operations that can change the
2696 // dimensions of the inputs (e.g. pad, slice) or for the bitcast operation.
2697 // For example:
2698 //
2699 // fused_computation {
2700 // param_0 = f32[64,64]{1,0} parameter(0)
2701 // ROOT bitcast = f32[64,64]{0,1} bitcast(param_0)
2702 // }
2703 // The output element at logical address [0, 63] depends on the input element
2704 // at logical address [63, 0], which would not be within the shared-memory
2705 // block.
2706 //
2707 // TODO(bixia): In order to extend this to kInput fusion (that is, reduction
2708 // with transpose), we only need to end the use-chain checking at the input of
2709 // a reduce operation. In that case, the above description of "output" applies
2710 // to the result of such a use-chain, which provides the input to the reduce
2711 // operation.
2712 bool IsInstructionSafeForShmemTranspose(const HloInstruction* hlo) {
2713 if (hlo->IsElementwise()) {
2714 return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) {
2715 return IsInstructionSafeForShmemTranspose(user);
2716 });
2717 }
2718
2719 switch (hlo->opcode()) {
2720 // Non-elementwise instructions that don't cause the shmem transpose
2721 // to be unsafe, including instructions that currently do not get fused.
2722 case HloOpcode::kGetDimensionSize:
2723 // The result of the operation doesn't rely on the content of the
2724 // tensor. As such, there is no need to further inspect its users.
2725 return true;
2726 case HloOpcode::kGetTupleElement:
2727 case HloOpcode::kMap:
2728 case HloOpcode::kParameter:
2729 case HloOpcode::kTuple:
2730 case HloOpcode::kTupleSelect:
2731 return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) {
2732 return IsInstructionSafeForShmemTranspose(user);
2733 });
2734
2735 default:
2736 return false;
2737 }
2738 }
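// For example, every parameter of the (hypothetical) fusion below is safe for
// a shmem transpose, because all transitive users (exponential, tuple) are in
// the allowed set, so output element [z, y, x] depends only on the input
// elements at the same indices:
//
//   fused_computation {
//     p0 = f32[16,32,64]{2,1,0} parameter(0)
//     e0 = f32[16,32,64]{2,1,0} exponential(p0)
//     ROOT t = (f32[16,32,64]{2,1,0}) tuple(e0)
//   }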
2739
2740 // Given a group of input parameters that are 0-2-1 transpose of the outputs of
2741 // a fusion kernel, returns the input parameters that are safe for the shared
2742 // memory transpose implementation.
2743 //
2744 // When a tile based shared memory transpose is used to implement an input with
2745 // 0-2-1 transpose, we preload a tile of the input elements
2746 // [z, y..y+31, x..x+31] to compute the output tile elements of the same
2747 // indices. Preloading the input tile this way is only safe when the computation
2748 // of the output tile elements does not need any input element outside the
2749 // preloaded tile. We inspect all the transitive users of the input parameter
2750 // up to the fusion root instruction to see if we can find any instruction
2751 // that can make preloading the input tile unsafe.
2752 std::vector<int64> FilterInputsForShmemTranspose(const HloInstruction* fusion,
2753 std::vector<int64> input_ids) {
2754 std::vector<int64> filtered_input_ids;
2755 for (int64 i = 0; i < input_ids.size(); ++i) {
2756 const HloInstruction* input = fusion->fused_parameter(input_ids[i]);
2757 if (IsInstructionSafeForShmemTranspose(input)) {
2758 filtered_input_ids.push_back(input_ids[i]);
2759 } else {
2760 VLOG(10) << "Input not safe for shmem transpose " << input->ToString();
2761 }
2762 }
2763 return filtered_input_ids;
2764 }
2765
2766 } // namespace
2767
2768 bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
2769 HloOpcode opcode = hlo->opcode();
2770
2771 CHECK(hlo->IsLoopFusion() || opcode == HloOpcode::kCopy);
2772
2773 const Shape& output_shape = hlo->IsMultiOutputFusion()
2774 ? ShapeUtil::GetSubshape(hlo->shape(), {0})
2775 : hlo->shape();
2776
2777 // If the output_shape is reduced to 021 shape, find all the parameters of
2778 // the HLO that are in the corresponding 012 shape.
2779 std::vector<int64> params_012;
2780 optional<std::vector<int64>> reduced_dims_021;
2781 for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
2782 ++operand_idx) {
2783 HloInstruction* operand = hlo->mutable_operand(operand_idx);
2784 auto find_transpose_result =
2785 ShapeUtil::FindTranspose021(operand->shape(), output_shape);
2786 if (!find_transpose_result.has_value()) {
2787 continue;
2788 }
2789 const std::vector<int64>& curr_reduced_dims_021 = *find_transpose_result;
2790 if (!reduced_dims_021.has_value()) {
2791 reduced_dims_021 = curr_reduced_dims_021;
2792 }
2793 if (!absl::c_equal(*reduced_dims_021, curr_reduced_dims_021)) {
2794 // There is more than one possible transpose. Instead of picking one
2795 // transpose, we simply give up here.
2796 return false;
2797 }
2798 params_012.push_back(operand_idx);
2799 }
2800
2801 if (!reduced_dims_021.has_value()) {
2802 return false;
2803 }
2804
2805 if ((*reduced_dims_021)[1] < kMinDimensionToTransposeTiled ||
2806 (*reduced_dims_021)[2] < kMinDimensionToTransposeTiled) {
2807 return false;
2808 }
2809
2810 if (opcode == HloOpcode::kFusion) {
2811 params_012 = FilterInputsForShmemTranspose(hlo, params_012);
2812 if (params_012.empty()) {
2813 return false;
2814 }
2815 }
2816
2817 // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the
2818 // elements are of size 4 bytes), and CUDA has an architectural limit of
2819 // 48kb shared memory per SM. (This is increased to 96kb in Volta, but we
2820 // don't use this, in part because it eats into our L1 cache space.)
2821 //
2822 // For correctness we need to ensure that we don't make more than 48kb worth
2823 // of shmem tiles per block. And for performance, we'd probably like to use
2824 // significantly less, so that we can fit more than one block at a time on a
2825 // gpu core.
2826 //
2827 // We say without benchmarks that we want at least 3 blocks per SM,
2828 // which limits us to about 3 shmem tiles per block if the elements are 32
2829 // bits wide. We choose which params get the shmem transpose treatment
2830 // arbitrarily; it's not clear if there's a Right Choice.
2831 //
2832 // This is only sound if tiled transposes are the only place where we use
2833 // shared memory in fusions. If in the future other fusible ops use shared
2834 // memory, we'll have to adjust this heuristic.
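// As a worked example (assuming f32 parameters): one tile is 32 * 33 * 4
// bytes = 4224 bytes, and kMinBlocksPerCore * shmem_used must stay within
// kShmemPerCore, i.e. shmem_used <= 16 KiB per block. That admits at most
// floor(16384 / 4224) = 3 tiled f32 parameters; a fourth candidate (and
// everything after it) would be dropped from params_012 below.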
2835 constexpr int kMinBlocksPerCore = 3;
2836 constexpr int64 kShmemPerCore = 48 * 1024;
2837 int64 shmem_used = 0;
2838 for (int64 i = 0; i < params_012.size(); ++i) {
2839 const HloInstruction* operand = hlo->operand(params_012[i]);
2840 shmem_used +=
2841 32 * 33 *
2842 ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type());
2843
2844 if (kMinBlocksPerCore * shmem_used > kShmemPerCore) {
2845 // Erase this element and everything after it from params_012.
2846 params_012.resize(i);
2847 break;
2848 }
2849 }
2850
2851 if (params_012.empty()) {
2852 return false;
2853 }
2854
2855 VLOG(3) << "EmitHlo021Tile: emitting 0-2-1 tiled kernel for " << hlo->ToString();
2856 std::unique_ptr<KernelThunk> kernel_thunk =
2857 BuildKernelThunk(hlo, /*implements_whole_instruction=*/true);
2858 EmitHlo021Tile(hlo, kernel_thunk.get(), *reduced_dims_021, params_012);
2859 AddThunkToThunkSequence(std::move(kernel_thunk));
2860 return true;
2861 }
2862
2863 namespace {
2864
2865 // Returns true if all transitive users of hlo, up to (but not including) the
2866 // users in use_chain_endings, are elementwise operations.
2867 bool AreUsersElementwise(const HloInstruction* hlo,
2868 const ConstHloInstructionSet& use_chain_endings) {
2869 return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) {
2870 return use_chain_endings.count(user) ||
2871 (user->IsElementwise() &&
2872 AreUsersElementwise(user, use_chain_endings));
2873 });
2874 }
2875
2876 // Returns the number of fusion inputs that have the same dimensions as the
2877 // given shape and are involved in only elementwise operations.
2878 int64 NumInputsInvolveInOnlyElementwiseOps(
2879 const HloInstruction* unnested_hlo, const Shape& op_shape,
2880 const ConstHloInstructionSet& use_chain_endings) {
2881 return absl::c_count_if(
2882 unnested_hlo->fused_parameters(), [&](const HloInstruction* parameter) {
2883 const Shape& parameter_shape = parameter->shape();
2884 return ShapeUtil::SameDimensions(op_shape, parameter_shape) &&
2885 AreUsersElementwise(parameter, use_chain_endings);
2886 });
2887 }
2888
2889 // Returns the number of fusion inputs that have more elements than the given
2890 // shape.
2891 int64 NumInputsWithMoreElementsThan(const HloInstruction* unnested_hlo,
2892 const Shape& shape) {
2893 int64 num_elements = ShapeUtil::ElementsIn(shape);
2894 return absl::c_count_if(
2895 unnested_hlo->fused_parameters(), [&](const HloInstruction* parameter) {
2896 return ShapeUtil::ElementsIn(parameter->shape()) > num_elements;
2897 });
2898 }
2899
2900 // The benefit of unrolling a kInput fusion that is a column reduction comes
2901 // from the vectorization of non-reduction fusion outputs and fusion inputs.
2902 // On the other hand, unrolling can also introduce factors that can cause
2903 // the kernel to run slower. This routine uses a simple heuristic to estimate
2904 // the benefit as well as the overhead of unrolling in order to decide whether
2905 // unrolling is beneficial for the given kInput fusion.
2906 bool IsUnrollingColumnReductionBeneficial(const HloInstruction* unnested_hlo,
2907 const Shape& input_shape,
2908 int64 num_kept_minor) {
2909 // TODO(b/122468062): Needs further investigation to see whether we can
2910 // remove the constraint on IsPowerOfTwo.
2911 if (!IsPowerOfTwo(static_cast<uint64>(num_kept_minor))) {
2912 return false;
2913 }
2914
2915 if (IsReductionFromOrToContiguousDimensions(*unnested_hlo)) {
2916 return true;
2917 }
2918
2919 CHECK_EQ(unnested_hlo->opcode(), HloOpcode::kFusion);
2920 int64 can_be_vectorized = 0;
2921 int64 cannot_be_vectorized = 0;
2922 const HloInstruction* fused_root = unnested_hlo->fused_expression_root();
2923 ConstHloInstructionSet use_chain_endings;
2924 if (IsReductionFromOrToContiguousDimensions(*fused_root)) {
2925 use_chain_endings.insert(fused_root);
2926 // Atomic.add of the reduction result can't be vectorized.
2927 cannot_be_vectorized++;
2928 } else {
2929 CHECK_EQ(fused_root->opcode(), HloOpcode::kTuple);
2930 for (const HloInstruction* instr : fused_root->operands()) {
2931 if (IsReductionFromOrToContiguousDimensions(*instr)) {
2932 // Atomic.add of the reduction result can't be vectorized.
2933 cannot_be_vectorized++;
2934 } else {
2935 // Write of the non-reduction result can be vectorized.
2936 can_be_vectorized++;
2937 }
2938 use_chain_endings.insert(instr);
2939 }
2940 }
2941 // Fusion inputs that have the same dimensions as the reduce input and are
2942 // involved in only elementwise operations can be vectorized.
2943 can_be_vectorized += NumInputsInvolveInOnlyElementwiseOps(
2944 unnested_hlo, input_shape, use_chain_endings);
2945 // Fusion inputs with more elements than the reduce op input must participate
2946 // in non-elementwise operations and we assume that they are not vectorizable
2947 // for the purpose of estimating the benefit of unrolling. If the kernel is
2948 // unrolled even with such an assumption, and the accesses to those inputs
2949 // turn out to be vectorizable, the compiler will still vectorize them.
2950 cannot_be_vectorized +=
2951 NumInputsWithMoreElementsThan(unnested_hlo, input_shape);
2952 return can_be_vectorized >= cannot_be_vectorized;
2953 }
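// Example (hypothetical): for a fusion whose root is tuple(reduce, add) where
// the add output has the same shape as the reduce input, the reduce output
// counts toward cannot_be_vectorized (its result is written with atomics)
// while the add output and any same-shaped, purely elementwise fusion inputs
// count toward can_be_vectorized, so unrolling is judged beneficial.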
2954
2955 } // namespace
2956
2957 ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo(
2958 const HloInstruction* unnested_hlo, const HloInstruction* first_reduce) {
2959 const Shape& input_shape = first_reduce->operand(0)->shape();
2960 ReductionDimensions reduction_dimensions =
2961 GetReductionKindAndContiguousComponents(*first_reduce);
2962 VLOG(10) << "is_row_reduction " << reduction_dimensions.is_row_reduction
2963 << " " << reduction_dimensions.dimensions[0] << " "
2964 << reduction_dimensions.dimensions[1] << " "
2965 << reduction_dimensions.dimensions[2];
2966
2967 std::array<int64, 3> reduction_tiling =
2968 GetReductionTiling(reduction_dimensions);
2969 bool dilated_x =
2970 reduction_dimensions.is_row_reduction ||
2971 !IsUnrollingColumnReductionBeneficial(unnested_hlo, input_shape,
2972 reduction_dimensions.dimensions[2]);
2973
2974 if (!dilated_x && !reduction_dimensions.is_row_reduction) {
2975 // Vectorized loads: a single thread reduces two adjacent columns.
2976 reduction_tiling[2] *= 2;
2977 }
2978
2979 int64 num_threads_y = 1;
2980 int64 num_threads_x = [&] {
2981 if (reduction_dimensions.is_row_reduction) {
2982 return kWarpSize;
2983 }
2984 return std::min(
2985 ThreadsPerBlockLimit(ir_emitter_context_->device_description()),
2986 CeilOfRatio(reduction_dimensions.dimensions[2], reduction_tiling[2]));
2987 }();
2988
2989 KernelMappingScheme mapping_scheme(
2990 reduction_dimensions.dimensions,
2991 {reduction_tiling[0], reduction_tiling[1] * num_threads_y,
2992 reduction_tiling[2] * num_threads_x},
2993 num_threads_y, num_threads_x, dilated_x);
2994 return ReductionCodegenInfo(mapping_scheme,
2995 reduction_dimensions.is_row_reduction);
2996 }
2997
2998 Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
2999 HloInstruction* unnested_hlo,
3000 absl::Span<HloInstruction* const> output_instructions) {
3001 bool returns_tuple = output_instructions.size() > 1;
3002 VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString();
3003
3004 std::vector<HloInstruction*> reduce_instructions;
3005 InlinedVector<ShapeIndex, 1> reduction_output_shape_indices;
3006 InlinedVector<HloComputation*, 1> reducers;
3007
3008 // Build an initializer thunk to initialize each reduction output.
3009 std::vector<std::unique_ptr<Thunk>> thunks;
3010 for (int i = 0; i < output_instructions.size(); ++i) {
3011 if (!IsReductionFromOrToContiguousDimensions(*output_instructions[i])) {
3012 continue;
3013 }
3014
3015 HloInstruction* output_instruction = output_instructions[i];
3016 reduce_instructions.push_back(output_instruction);
3017 ShapeIndex idx = returns_tuple ? ShapeIndex({i}) : ShapeIndex({});
3018 reduction_output_shape_indices.push_back(idx);
3019 reducers.push_back(output_instruction->to_apply());
3020
3021 TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
3022 BuildInitializerThunk(unnested_hlo, idx));
3023 thunks.push_back(std::move(initializer_thunk));
3024 }
3025
3026 const HloInstruction* first_reduce = reduce_instructions.at(0);
3027 if (output_instructions.size() > 1) {
3028 if (!AreFusedReductionOutputsConsistent(output_instructions,
3029 first_reduce)) {
3030 return InternalError("Inconsistent reduction fusion outputs");
3031 }
3032 }
3033
3034 // Build a kernel thunk to compute all the outputs.
3035 std::unique_ptr<KernelThunk> kernel_thunk =
3036 BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false);
3037
3038 const Shape& input_shape = first_reduce->operand(0)->shape();
3039 // The layout of a reduction input is either set by LayoutAssignment for
3040 // unnested kReduce or by InstructionFusion for fused kReduce.
3041 CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion "
3042 "doesn't set the input layout of "
3043 << first_reduce->ToString();
3044
3045 ReductionCodegenInfo reduction_info =
3046 ComputeReductionCodegenInfo(unnested_hlo, first_reduce);
3047 const KernelMappingScheme& mapping_scheme =
3048 reduction_info.GetKernelMappingScheme();
3049 LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
3050 mapping_scheme.GetThreadsPerBlock());
3051 llvm::Type* index_ty = GetIndexTypeForKernel(
3052 unnested_hlo, launch_dimensions.launch_bound(), &b_);
3053 EmitPrologueForReduction(unnested_hlo, &reduction_info, reduce_instructions,
3054 index_ty);
3055 EmitElementFunction emit_reduction_tile =
3056 [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
3057 llvm::Value* x_loc, int64 x_iter_num) {
3058 EmitTileElementForReduction(unnested_hlo, input_shape,
3059 output_instructions, index, reduction_info,
3060 reducers, x_iter_num);
3061 };
3062
3063 TilingKernelInfo tiling_kernel_info = EmitTilingKernel(
3064 mapping_scheme, index_ty,
3065 [&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
3066 const string& loop_name, llvm::Value* tile_height,
3067 llvm::Value* tile_width, KernelSupportLibrary* ksl) {
3068 EmitTile(reduction_info.GetKernelMappingScheme(), index, loop_name, ksl,
3069 &b_, y, x, tile_height, tile_width, emit_reduction_tile);
3070 });
3071 EmitEpilogueForReduction(index_ty, unnested_hlo, reduction_info,
3072 reduce_instructions, reduction_output_shape_indices,
3073 reducers, tiling_kernel_info);
3074
3075 UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
3076 ir_emitter_context_->llvm_module());
3077
3078 thunks.push_back(std::move(kernel_thunk));
3079 auto sequential_thunk =
3080 absl::make_unique<SequentialThunk>(std::move(thunks), unnested_hlo);
3081 AddThunkToThunkSequence(std::move(sequential_thunk));
3082
3083 return Status::OK();
3084 }
3085
3086 Status IrEmitterUnnested::EmitConstantGlobals() {
3087 for (const BufferAllocation& allocation :
3088 ir_emitter_context_->buffer_assignment().Allocations()) {
3089 if (!allocation.is_constant()) {
3090 continue;
3091 }
3092
3093 const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation);
3094 const bool should_emit_initializer = ShouldEmitLiteralInLlvmIr(literal);
3095 llvm::ArrayType* global_type =
3096 llvm::ArrayType::get(b_.getInt8Ty(), allocation.size());
3097 llvm::Constant* initializer =
3098 should_emit_initializer
3099 ? llvm_ir::ConvertLiteralToIrConstant(literal, module_)
3100 : llvm::ConstantAggregateZero::get(global_type);
3101 if (should_emit_initializer) {
3102 VLOG(3) << "Emitted initializer for constant with shape "
3103 << ShapeUtil::HumanString(literal.shape());
3104 }
3105
3106 // These globals will be looked up by name by GpuExecutable so we need to
3107 // give them an external linkage. Not all of their uses are visible in
3108 // the LLVM IR (e.g. TupleThunk) so we can't give them a linkage that
3109 // merely preserves their names (like available_externally), we also need
3110 // to ensure that they stick around even if they're "unused".
3111 //
3112 // We may have to be more clever here in the future if we notice that
3113 // we're keeping around too many globals because of their linkage.
3114 unsigned global_address_space = llvm_ir::GetGlobalMemoryAddressSpace(
3115 *ir_emitter_context_->llvm_module());
3116 llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
3117 global_type, /*isConstant=*/should_emit_initializer,
3118 llvm::GlobalValue::ExternalLinkage,
3119 /*Initializer=*/initializer,
3120 llvm_ir::ConstantBufferAllocationToGlobalName(allocation),
3121 /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
3122 /*AddressSpace=*/global_address_space,
3123 /*isExternallyInitialized=*/false);
3124 global_for_const->setAlignment(kConstantBufferAlignBytes);
3125 ir_emitter_context_->llvm_module()->getGlobalList().push_back(
3126 global_for_const);
3127 }
3128
3129 return Status::OK();
3130 }
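// A sketch of the IR this emits for one constant allocation (the global name,
// array size, address space and alignment below are illustrative; the real
// name comes from ConstantBufferAllocationToGlobalName and the alignment from
// kConstantBufferAlignBytes):
//
//   @buffer_for_constant_7 = constant [256 x i8] c"...", align 64
//
// When ShouldEmitLiteralInLlvmIr returns false, the initializer is instead a
// zeroinitializer of the same size and the buffer contents can be filled in
// when the executable is loaded, outside of the LLVM IR.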
3131
3132 // Emits code for slices based on the structure below. An if statement with
3133 // a guarding condition is generated for each ROOT slice.
3134 //
3135 // Pseudo code:
3136 //
3137 // Compute values of slice input operands
3138 //
3139 // Compute guarding_cond0
3140 // if (guarding_cond0) {
3141 // Write to output of slice0
3142 // }
3143 //
3144 // Compute guarding_cond1
3145 // if (guarding_cond1) {
3146 // Write to output of slice1
3147 // }
3148 //
3149 void IrEmitterUnnested::EmitElementForInputFusibleSlices(
3150 HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index) {
3151 VLOG(10) << "Emitting slice input fusion for " << unnested_hlo->ToString();
3152
3153 HloInstruction* slice_or_tuple = unnested_hlo->fused_expression_root();
3154 auto slice_instructions = [&]() -> absl::Span<HloInstruction* const> {
3155 if (slice_or_tuple->opcode() == HloOpcode::kSlice) {
3156 return absl::Span<HloInstruction* const>(&slice_or_tuple, 1);
3157 }
3158 CHECK_EQ(slice_or_tuple->opcode(), HloOpcode::kTuple);
3159 return slice_or_tuple->operands();
3160 }();
3161
3162 // Emit input operand values of slices.
3163 std::vector<llvm::Value*> input_ir_values;
3164 GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
3165 GetNestedComputer());
3166 FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo),
3167 &elem_emitter);
3168 TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter));
3169 for (const HloInstruction* slice : slice_instructions) {
3170 auto input_generator = fused_emitter.GetGenerator(slice->operand(0));
3171 input_ir_values.push_back(input_generator(index).ValueOrDie());
3172 }
3173
3174 // Emit for slice_instructions.
3175 KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
3176 for (int64 i = 0; i < slice_instructions.size(); ++i) {
3177 HloInstruction* slice = slice_instructions[i];
3178
3179 // guarding_cond := index >= start && index < limit, for each dim.
3180 std::vector<llvm::Value*> index_within_ranges;
3181 for (size_t dim = 0; dim < slice->slice_starts().size(); ++dim) {
3182 CHECK_EQ(slice->slice_strides(dim), 1);
3183 auto larger_or_equal_than_start = b_.CreateICmpSGE(
3184 index.multidim()[dim],
3185 index.GetConstantWithIndexType(slice->slice_starts(dim)));
3186 llvm::Value* smaller_than_limit = b_.CreateICmpSLT(
3187 index.multidim()[dim],
3188 index.GetConstantWithIndexType(slice->slice_limits(dim)));
3189 llvm::Value* within_range =
3190 b_.CreateAnd(larger_or_equal_than_start, smaller_than_limit);
3191 index_within_ranges.push_back(within_range);
3192 }
3193 llvm::Value* guarding_cond = b_.CreateAnd(index_within_ranges);
3194
3195 auto emit_slice_elem_func = [&] {
3196 const std::vector<llvm::Value*>& src_multidim = index.multidim();
3197 std::vector<llvm::Value*> dst_multidim(src_multidim.size());
3198 for (size_t dim = 0; dim < src_multidim.size(); ++dim) {
3199 dst_multidim[dim] =
3200 Sub(src_multidim[dim],
3201 index.GetConstantWithIndexType(slice->slice_starts(dim)));
3202 }
3203 ShapeIndex shape_index = (slice_or_tuple->opcode() == HloOpcode::kSlice)
3204 ? ShapeIndex()
3205 : ShapeIndex({i});
3206 llvm_ir::IrArray src_ir_array =
3207 GetIrArray(*unnested_hlo, *unnested_hlo, shape_index);
3208 IrArray::Index slice_dst_index(dst_multidim, slice->shape(),
3209 index.GetType());
3210 llvm::Value* dst_addr = src_ir_array.EmitArrayElementAddress(
3211 slice_dst_index, &b_, "slice.dest");
3212 b_.CreateStore(input_ir_values[i], dst_addr);
3213 };
3214
3215 ksl.If(StrCat("slice", i), guarding_cond, emit_slice_elem_func);
3216 }
3217 }
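// A schematic example of a fusion this handles (shapes and slice bounds are
// hypothetical):
//
//   fused_computation {
//     p0 = f32[1024]{0} parameter(0)
//     n0 = f32[1024]{0} negate(p0)
//     s0 = f32[512]{0} slice(n0), slice={[0:512]}
//     s1 = f32[256]{0} slice(n0), slice={[512:768]}
//     ROOT t = (f32[512]{0}, f32[256]{0}) tuple(s0, s1)
//   }
//
// One kernel is emitted over the consistent input shape f32[1024]; each thread
// computes n0 for its index once and then executes the guarded store for every
// slice whose [start, limit) range contains that index.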
3218
3219 Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices(
3220 HloInstruction* unnested_hlo) {
3221 constexpr int unroll_factor = 1;
3222 std::unique_ptr<KernelThunk> kernel_thunk = BuildKernelThunk(
3223 unnested_hlo, /*implements_whole_instruction=*/true, unroll_factor);
3224
3225 TF_ASSIGN_OR_RETURN(Shape element_shape,
3226 GetConsistentInputShapeForRootSlices(*unnested_hlo));
3227 LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
3228 element_shape, ir_emitter_context_->device_description(), unroll_factor);
3229 UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
3230 ir_emitter_context_->llvm_module());
3231
3232 Status emit_status =
3233 ParallelLoopEmitter(
3234 [&](const llvm_ir::IrArray::Index index) -> Status {
3235 EmitElementForInputFusibleSlices(unnested_hlo, index);
3236 return Status::OK();
3237 },
3238 element_shape, launch_dimensions, &b_)
3239 .EmitLoop(IrName(unnested_hlo),
3240 GetIndexTypeForKernel(
3241 unnested_hlo, launch_dimensions.launch_bound(), &b_));
3242
3243 thunk_sequence_->emplace_back(std::move(kernel_thunk));
3244
3245 return emit_status;
3246 }
3247
3248 } // namespace gpu
3249 } // namespace xla
3250