1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
17 
18 #include <algorithm>
19 #include <cstring>
20 #include <iterator>
21 #include <memory>
22 #include <string>
23 #include <vector>
24 
25 #include "absl/algorithm/container.h"
26 #include "absl/container/inlined_vector.h"
27 #include "absl/memory/memory.h"
28 #include "absl/strings/str_cat.h"
29 #include "absl/types/optional.h"
30 #include "absl/types/span.h"
31 #include "llvm/ADT/StringRef.h"
32 #include "llvm/IR/BasicBlock.h"
33 #include "llvm/IR/Function.h"
34 #include "llvm/IR/IRBuilder.h"
35 #include "llvm/IR/Instructions.h"
36 #include "llvm/IR/LLVMContext.h"
37 #include "llvm/IR/Module.h"
38 #include "tensorflow/compiler/xla/layout_util.h"
39 #include "tensorflow/compiler/xla/literal.h"
40 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
41 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
42 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
43 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
44 #include "tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h"
45 #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
46 #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
47 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
48 #include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
49 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
50 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
51 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
52 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
53 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
54 #include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
55 #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
56 #include "tensorflow/compiler/xla/service/gpu/memset_thunk.h"
57 #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
58 #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
59 #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
60 #include "tensorflow/compiler/xla/service/gpu/replica_id_thunk.h"
61 #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
62 #include "tensorflow/compiler/xla/service/gpu/target_util.h"
63 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
64 #include "tensorflow/compiler/xla/service/gpu/tuple_thunk.h"
65 #include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
66 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
67 #include "tensorflow/compiler/xla/service/hlo_computation.h"
68 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
69 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
70 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
71 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
72 #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
73 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
74 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
75 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
76 #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"
77 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
78 #include "tensorflow/compiler/xla/service/name_uniquer.h"
79 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
80 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
81 #include "tensorflow/compiler/xla/shape_util.h"
82 #include "tensorflow/compiler/xla/status_macros.h"
83 #include "tensorflow/compiler/xla/types.h"
84 #include "tensorflow/compiler/xla/util.h"
85 #include "tensorflow/compiler/xla/window_util.h"
86 #include "tensorflow/compiler/xla/xla_data.pb.h"
87 #include "tensorflow/core/lib/core/bits.h"
88 #include "tensorflow/core/lib/core/status.h"
89 #include "tensorflow/core/platform/logging.h"
90 
91 namespace xla {
92 namespace gpu {
93 
94 namespace {
95 
96 using absl::InlinedVector;
97 using absl::nullopt;
98 using absl::optional;
99 using absl::StrCat;
100 using llvm_ir::IrArray;
101 using llvm_ir::IrName;
102 
103 const auto kDimX = KernelMappingScheme::DimX;
104 const auto kDimY = KernelMappingScheme::DimY;
105 const auto kDimZ = KernelMappingScheme::DimZ;
106 const auto kDimTot = KernelMappingScheme::DimTot;
107 
108 // If a dimension is smaller than this, untiled transposition may be more
109 // efficient.
110 const int64 kMinDimensionToTransposeTiled = 16;
111 
112 // Updates the launch dimensions in "thunk" and annotates the launch dimensions
113 // of the corresponding IR kernel in "llvm_module".
114 // Precondition: "thunk" must be a KernelThunk.
115 void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk,
116                             llvm::Module* llvm_module) {
117   CHECK(Thunk::Kind::kKernel == thunk->kind());
118   KernelThunk* kernel_thunk = static_cast<KernelThunk*>(thunk);
119   kernel_thunk->SetLaunchDimensions(launch_dims);
120 
121   // Add __launch_bounds__ to metadata. This limits registers per thread to
122   // avoid out-of-resources launching errors.
123   llvm::NamedMDNode* nvvm_annotations_node =
124       llvm_module->getOrInsertNamedMetadata("nvvm.annotations");
125   llvm::Function* ir_kernel =
126       llvm_module->getFunction(kernel_thunk->kernel_name().c_str());
127   llvm::LLVMContext& llvm_context = llvm_module->getContext();
128   llvm::ConstantInt* threads_per_block_ir_value = llvm::ConstantInt::get(
129       llvm::IntegerType::get(llvm_context, /*NumBits=*/32),
130       launch_dims.threads_per_block());
131   // Our launch bounds are exact, so we can specify them as reqntidx rather than
132   // maxntidx.
133   nvvm_annotations_node->addOperand(llvm::MDNode::get(
134       llvm_context,
135       {llvm::ConstantAsMetadata::get(ir_kernel),
136        llvm::MDString::get(llvm_context, "reqntidx"),
137        llvm::ConstantAsMetadata::get(threads_per_block_ir_value)}));
138 }
139 
140 }  // namespace
141 
142 IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
143                                      const HloComputation* hlo_computation,
144                                      IrEmitterContext* ir_emitter_context)
145     : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false),
146       hlo_computation_(hlo_computation) {
147   // Initialize thunk_sequence_ to an empty list of thunks.
148   thunk_sequence_.reset(new ThunkSequence());
149 }
150 
151 Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) {
152   bindings_.UnbindAllLocalIrValues();
153   return DfsHloVisitor::Postprocess(hlo);
154 }
155 
156 llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
157     const HloInstruction& inst,
158     absl::Span<const BufferAllocation* const> args) {
159   // Compute the kernel name. The opcode string may contain "-" which cannot be
160   // in a PTX function name, so sanitize the name before uniquifying it.
161   string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName(
162       llvm_ir::SanitizeFunctionName(inst.name()));
163 
164   // Create the kernel and add it to the module.
165   llvm::Module* module = ir_emitter_context_->llvm_module();
166   llvm::LLVMContext& context = module->getContext();
167   llvm::FunctionType* kernel_type = llvm::FunctionType::get(
168       /*Result=*/llvm::Type::getVoidTy(context),
169       std::vector<llvm::Type*>(args.size(), b_.getInt8PtrTy()),
170       /*isVarArg=*/false);
171   llvm::Function* kernel =
172       llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage,
173                              kernel_name.c_str(), module);
174 
175   // Add dereferenceable and alignment information to each of the kernel's
176   // parameters.
177   auto arg_it = kernel->arg_begin();
178   for (size_t arg_no = 0; arg_no < args.size(); ++arg_no) {
179     const BufferAllocation* alloc = args[arg_no];
180     llvm::Argument* fn_arg = &*arg_it;
181     ++arg_it;
182 
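    // Note: addDereferenceableAttr takes a 1-based attribute index (index 0 is
    // the return value), so parameter arg_no corresponds to arg_no + 1 here.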
183     kernel->addDereferenceableAttr(arg_no + 1, alloc->size());
184 
185     const int64 alignment = [&] {
186       if (alloc->is_entry_computation_parameter()) {
187         return kEntryParameterAlignBytes;
188       } else if (alloc->is_constant()) {
189         return kConstantBufferAlignBytes;
190       } else {
191         return kXlaAllocatedBufferAlignBytes;
192       }
193     }();
194 
195     kernel->addParamAttr(
196         arg_no,
197         llvm::Attribute::get(context, llvm::Attribute::Alignment, alignment));
198 
199     if (alloc->IsPreallocatedTempBuffer()) {
200       fn_arg->setName("temp_buf");
201     } else {
202       fn_arg->setName(StrCat("alloc", alloc->index()));
203     }
204   }
205 
206   AnnotateFunctionAsGpuKernel(module, kernel, &b_);
207 
208   // TODO(b/65380986): Investigate if adding fast math flags for generated
209   // kernels makes sense.
210 
211   // Update the insert point to the entry basic block.
212   llvm::BasicBlock* entry_bb =
213       llvm::BasicBlock::Create(context, /*Name=*/"entry", /*Parent=*/kernel);
214 
215   // Emit a "return void" at entry_bb's end, and set the insert point before
216   // that return instruction.
217   b_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb));
218 
219   return kernel;
220 }
221 
222 namespace {
223 // Computes the maximum valid unroll factor for a given instruction.
224 int ComputeMaxUnrollFactor(const HloInstruction* hlo) {
225   int max_unroll_factor = hlo->GetModule()
226                               ->config()
227                               .debug_options()
228                               .xla_gpu_max_kernel_unroll_factor();
229 
230   // Find the largest possible power of two to unroll by.
231   // TODO(kramerb): Make this smarter.
232   const Shape& element_shape = hlo->IsMultiOutputFusion()
233                                    ? ShapeUtil::GetSubshape(hlo->shape(), {0})
234                                    : hlo->shape();
235   int64 num_elements = ShapeUtil::ElementsIn(element_shape);
236   for (int i = max_unroll_factor; i > 1; i /= 2) {
237     if (num_elements % i == 0) {
238       return i;
239     }
240   }
241 
242   // Cannot unroll.
243   return 1;
244 }
245 
246 // Returns the llvm type for the indices used in the kernel that contains the
247 // hlo instruction. Such indices include the index for the parallel loop and
248 // the indices for the tensors accessed by the kernel. The return type is i32
249 // iff the following conditions are met:
250 //  . The launch_size of the kernel is within the range of i32.
251 //  . The sizes of all the tensors accessed within the kernel are within the
252 //    range of i32.
253 // Otherwise, the return type is i64.
254 llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size,
255                                   llvm::IRBuilder<>* b) {
256   // Find the unnested hlo instruction for which the kernel is generated.
257   const HloInstruction* unnested_hlo = hlo;
258   const HloComputation* computation = hlo->parent();
259   if (computation->IsFusionComputation()) {
260     unnested_hlo = computation->FusionInstruction();
261   }
262 
263   auto shape_in_range = [&](const Shape& s) {
264     bool in_range = true;
265     ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape,
266                                       const ShapeIndex& /*index*/) {
267       if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) {
268         in_range = false;
269       }
270     });
271 
272     return in_range;
273   };
274 
275   llvm::Type* i64_ty = b->getInt64Ty();
276   // Check launch dimension
277   if (!IsInt32(launch_size)) {
278     return i64_ty;
279   }
280 
281   // Check the size of result tensors
282   if (!shape_in_range(unnested_hlo->shape())) {
283     return i64_ty;
284   }
285 
286   auto hlo_shape_in_range = [&](const HloInstruction* operand) -> bool {
287     return shape_in_range(operand->shape());
288   };
289 
290   // Check the size of input tensors
291   if (!absl::c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) {
292     return i64_ty;
293   }
294 
295   // Check the size of the internal result tensors
296   if (unnested_hlo->opcode() == HloOpcode::kFusion) {
297     if (!absl::c_all_of(
298             unnested_hlo->fused_instructions_computation()->instructions(),
299             hlo_shape_in_range)) {
300       return i64_ty;
301     }
302   }
303 
304   return b->getInt32Ty();
305 }
306 
307 // Gets the input shape of the ROOT slices, which will be used as the kernel
308 // launch dims. The slice input fusion requires the input shapes of the ROOT
309 // slices to be the same although the (slice) output shapes can be different.
310 //
311 // Returns the input shape of the ROOT slices if all the input shapes of ROOT
312 // slices are the same and the slices are non-strided. Otherwise, returns
313 // FailedPrecondition.
314 StatusOr<Shape> GetConsistentInputShapeForRootSlices(
315     const HloInstruction& fusion) {
316   if (!IsInputFusibleSlices(fusion, /*verify_no_strides=*/true)) {
317     return FailedPrecondition(
318         "Unsupported root for slice input fusion. "
319         "Only non-strided slices are supported.");
320   }
321 
322   const HloInstruction& root = *fusion.fused_expression_root();
323   if (root.opcode() == HloOpcode::kSlice) {
324     return root.operands()[0]->shape();
325   }
326 
327   CHECK_EQ(root.opcode(), HloOpcode::kTuple);
328   const Shape& first_slice_operand_shape =
329       root.operands()[0]->operands()[0]->shape();
330   for (size_t i = 1; i < root.operands().size(); ++i) {
331     const HloInstruction* slice = root.operands()[i];
332     const Shape& operand_shape = slice->operands()[0]->shape();
333     if (!ShapeUtil::EqualIgnoringElementType(first_slice_operand_shape,
334                                              operand_shape)) {
335       return FailedPrecondition(
336           "Fused slices do not have the same input shape, fused computation = "
337           "%s.",
338           root.parent()->name());
339     }
340   }
341 
342   return first_slice_operand_shape;
343 }
344 
345 }  // namespace
346 
347 Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
348   return IrEmitter::DefaultAction(hlo);
349 }
350 
351 Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
352   AddThunkToThunkSequence(
353       BuildKernelThunk(dot, /*implements_whole_instruction=*/true));
354   return IrEmitter::HandleDot(dot);
355 }
356 
357 Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
358   AddThunkToThunkSequence(BuildConditionalThunk(conditional));
359   return Status::OK();
360 }
361 
362 Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) {
363   AddThunkToThunkSequence(
364       BuildKernelThunk(convolution, /*implements_whole_instruction=*/true));
365   return IrEmitter::HandleConvolution(convolution);
366 }
367 
368 Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
369   return ThunkEmitter(this).HandleCustomCall(custom_call);
370 }
371 
372 Status IrEmitterUnnested::HandleFft(HloInstruction* fft) {
373   return ThunkEmitter(this).HandleFft(fft);
374 }
375 
376 Status IrEmitterUnnested::HandleTriangularSolve(HloInstruction* hlo) {
377   return ThunkEmitter(this).HandleTriangularSolve(hlo);
378 }
379 
380 Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
381   HloInstruction* root = fusion->fused_expression_root();
382   if (fusion->IsInputFusion()) {
383     switch (root->opcode()) {
384       case HloOpcode::kScatter: {
385         std::vector<std::unique_ptr<Thunk>> thunks;
386         // The initialization from 'operand' uses different loop bounds, so
387         // emit it in a separate kernel. Treat it like a loop fusion, writing to
388         // the output buffer.
389         {
390           int unroll_factor = ComputeMaxUnrollFactor(fusion);
391           thunks.push_back(BuildKernelThunk(
392               fusion, /*implements_whole_instruction=*/false, unroll_factor));
393           GpuElementalIrEmitter operand_elemental_emitter(
394               hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
395               GetNestedComputer());
396           FusedIrEmitter operand_fused_emitter(
397               GetGeneratorForOperandIrArrays(fusion),
398               &operand_elemental_emitter);
399           TF_RETURN_IF_ERROR(
400               root->mutable_operand(0)->Accept(&operand_fused_emitter));
401 
402           TF_RETURN_IF_ERROR(EmitTargetElementLoopInThunk(
403               *fusion, operand_fused_emitter.GetGenerator(root->operand(0)),
404               static_cast<KernelThunk*>(thunks.back().get())));
405         }
406 
407         // Now build the actual scatter, reading and writing to the freshly
408         // filled output buffer.
409         {
410           thunks.push_back(
411               BuildKernelThunk(fusion,
412                                /*implements_whole_instruction=*/false));
413           // Spin up a new fused emitter for the scatter kernel and emit it.
414           GpuElementalIrEmitter scatter_elemental_emitter(
415               hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
416               GetNestedComputer());
417           FusedIrEmitter scatter_fused_emitter(
418               GetGeneratorForOperandIrArrays(fusion),
419               &scatter_elemental_emitter);
420           TF_RETURN_IF_ERROR(root->Accept(&scatter_fused_emitter));
421           TF_RETURN_IF_ERROR(EmitScatter(
422               thunks.back().get(), root,
423               /*scatter_indices_gen=*/
424               scatter_fused_emitter.GetGenerator(root->operand(1)),
425               /*updates_gen=*/
426               scatter_fused_emitter.GetGenerator(root->operand(2))));
427         }
428         AddThunkToThunkSequence(
429             absl::make_unique<SequentialThunk>(std::move(thunks), fusion));
430         return Status::OK();
431       }
432       // In the case of a root tuple, the fusion can be either a reduce or a
433       // slice input fusion.
434       case HloOpcode::kTuple: {
435         if (IsInputFusibleSlices(*fusion)) {
436           return EmitInputFusibleNonStridedSlices(fusion);
437         }
438 
439         CHECK_GE(root->operand_count(), 1);
440         return EmitReductionFromOrToContiguousDimensions(fusion,
441                                                          root->operands());
442       }
443       case HloOpcode::kReduce: {
444         // HandleFusion specializes reduction from a multi-dimensional array to
445         // a 1D array. The specialized version requires an initializer thunk that
446         // initializes the output array to the initial value of the reduce.
447         if (root->shape().IsTuple()) {
448           // TODO(b/129089333): Support tiled vectorized variadic reduce.
449           return Unimplemented(
450               "Vectorized variadic reduce is not supported on GPU");
451         }
452         return EmitReductionFromOrToContiguousDimensions(fusion, {root});
453       }
454       case HloOpcode::kSlice: {
455         return EmitInputFusibleNonStridedSlices(fusion);
456       }
457       default:
458         LOG(FATAL) << "Bad opcode for input fusion: "
459                    << fusion->fused_expression_root()->opcode();
460     }
461   } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(
462                  fusion, ir_emitter_context_->buffer_assignment())) {
463     // Fusion node with dynamic-update-slice as the root where the op's input
464     // (i.e. array to update) shares the same slice as its output.  In this case
465     // we have a special algorithm that modifies the output in place without
466     // touching the un-updated elements.
467 
468     // Set up kernel thunk and fused ir emitter.
469     std::unique_ptr<KernelThunk> fusion_thunk =
470         BuildKernelThunk(fusion, /*implements_whole_instruction=*/true);
471     GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
472                                             ir_emitter_context_->llvm_module(),
473                                             &b_, GetNestedComputer());
474 
475     // Shape of the dynamic-update-slice's "update" operand.
476     Shape update_shape = root->operand(1)->shape();
477 
478     // Array to write into.  Because this is an in-place operation, this is the
479     // same as operand 0's array.
480     IrArray output_array = GetIrArray(*fusion, *fusion);
481 
482     LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
483         update_shape, ir_emitter_context_->device_description());
484     UpdateLaunchDimensions(launch_dimensions, fusion_thunk.get(),
485                            ir_emitter_context_->llvm_module());
486     AddThunkToThunkSequence(std::move(fusion_thunk));
487 
488     return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace(
489         fusion, GetGeneratorForOperandIrArrays(fusion), output_array,
490         &elemental_emitter, launch_dimensions, &b_);
491   }
492 
493   CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop)
494       << ": " << fusion->ToString();
495 
496   if (CheckAndEmitHloWithTile021(fusion)) {
497     return Status::OK();
498   }
499 
500   return IrEmitter::HandleFusion(fusion);
501 }
502 
503 Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
504   CHECK(ShapeUtil::Compatible(copy->operand(0)->shape(), copy->shape()));
505   const BufferAssignment& buffer_assignment =
506       ir_emitter_context_->buffer_assignment();
507   if (LayoutUtil::Equal(copy->operand(0)->shape().layout(),
508                         copy->shape().layout()) &&
509       buffer_assignment.GetUniqueTopLevelSlice(copy->operand(0)).ok()) {
510     // Copy the operand into the output if it's not the same buffer already.
511     auto operand_buffer = GetAllocationSlice(*copy->operand(0));
512     auto destination_buffer = GetAllocationSlice(*copy);
513     if (operand_buffer != destination_buffer) {
514       AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
515           /*source_address=*/operand_buffer,
516           /*destination_buffer=*/destination_buffer,
517           /*mem_size=*/
518           ByteSizeOf(copy->operand(0)->shape()), copy));
519     }
520     return Status::OK();
521   }
522   if (CheckAndEmitHloWithTile021(copy)) {
523     return Status::OK();
524   }
525 
526   return IrEmitter::HandleCopy(copy);
527 }
528 
529 Status IrEmitterUnnested::EmitExtraOutputsForReduce(
530     const HloInstruction* unnested_hlo, const IrArray::Index& index,
531     bool use_linear_index,
532     absl::Span<const std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
533         extra_output_gens) {
534   for (int i = 0; i != extra_output_gens.size(); ++i) {
535     llvm::Value* extra_output_address =
536         GetIrArray(*unnested_hlo, *unnested_hlo, extra_output_gens[i].second)
537             .EmitArrayElementAddress(index, &b_, "extra_output_element_address",
538                                      use_linear_index);
539     TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
540                         extra_output_gens[i].first(index));
541     Store(extra_output_ir_value, extra_output_address);
542   }
543   return Status::OK();
544 }
545 
546 Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
547   if (IsReductionFromOrToContiguousDimensions(*reduce) &&
548       reduce->shape().IsArray()) {
549     return EmitReductionFromOrToContiguousDimensions(reduce, {reduce});
550   }
551 
552   return IrEmitter::HandleReduce(reduce);
553 }
554 
555 Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
556   // For the root node of the entry computation we can elide writing the tuple
557   // buffer. We can always figure out the contents of the tuples from buffer
558   // assignment because we insert copies to ensure non-ambiguous output buffers.
559   // GpuExecutable never reads the tuple buffer.
560   if (tuple ==
561       tuple->parent()->parent()->entry_computation()->root_instruction()) {
562     return Status::OK();
563   }
564   bool all_tuple_elements_have_buffer =
565       absl::c_all_of(tuple->operands(), [&](HloInstruction* tuple_element) {
566         return ir_emitter_context_->buffer_assignment()
567             .GetUniqueTopLevelSlice(tuple_element)
568             .ok();
569       });
570   // TODO(b/111689850): This logic isn't quite correct.
571   //
572   // Tuples (especially tuples that are the final result of a computation) can
573   // be so huge that if we were to emit a kernel that took each tuple element as
574   // a parameter, we would exceed the max allowable number of parameters to a
575   // GPU kernel, b/31336476. As an optimization, if all tuple elements have a
576   // buffer, we collect their buffer addresses in a host array, and then copy
577   // that array to the tuple's buffer.
578   //
579   // Some tuple elements might not have an unambiguous buffer (like the result
580   // of a select-tuple). In that case, we fall back to emitting kernels which
581   // have access to their buffer addresses in code.
582   if (all_tuple_elements_have_buffer) {
583     std::vector<BufferAllocation::Slice> tuple_element_buffers;
584     for (const HloInstruction* tuple_element : tuple->operands()) {
585       tuple_element_buffers.push_back(GetAllocationSlice(*tuple_element));
586     }
587     AddThunkToThunkSequence(absl::make_unique<TupleThunk>(
588         tuple_element_buffers, GetAllocationSlice(*tuple), tuple));
589     return Status::OK();
590   }
591   AddThunkToThunkSequence(
592       BuildKernelThunk(tuple, /*implements_whole_instruction=*/true));
593   return IrEmitter::HandleTuple(tuple);
594 }
595 
596 Status IrEmitterUnnested::HandleGetTupleElement(HloInstruction*) {
597   // GetTupleElement IR is emitted in the IR context of the user instruction,
598   // and so we do not build a kernel for GetTupleElement instructions.
599   return Status::OK();
600 }
601 
602 Status IrEmitterUnnested::HandleSelectAndScatter(
603     HloInstruction* select_and_scatter) {
604   CHECK_EQ(select_and_scatter->operand_count(), 3);
605   const auto* operand = select_and_scatter->operand(0);
606   const auto* source = select_and_scatter->operand(1);
607   const Window& window = select_and_scatter->window();
608   PrimitiveType operand_element_type = operand->shape().element_type();
609   const int64 rank = operand->shape().rank();
610   CHECK_EQ(rank, source->shape().rank());
611   CHECK_EQ(rank, window.dimensions_size());
612 
613   TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
614                       BuildInitializerThunk(select_and_scatter));
615   std::vector<std::unique_ptr<Thunk>> thunks;
616   thunks.push_back(std::move(initializer_thunk));
617   thunks.push_back(BuildKernelThunk(select_and_scatter,
618                                     /*implements_whole_instruction=*/false));
619   std::unique_ptr<SequentialThunk> select_and_scatter_thunk =
620       absl::make_unique<SequentialThunk>(std::move(thunks), select_and_scatter);
621 
622   // TODO(b/31410564): Implement dilation rate for select-and-scatter.
623   if (window_util::HasDilation(window)) {
624     return Unimplemented(
625         "Dilation for SelectAndScatter not implemented on GPU.");
626   }
627 
628   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
629       source->shape(), ir_emitter_context_->device_description());
630   llvm::Type* index_type = GetIndexTypeForKernel(
631       select_and_scatter, launch_dimensions.launch_bound(), &b_);
632   auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
633     return llvm::ConstantInt::get(index_type, c);
634   };
635 
636   // kSelectAndScatter is implemented as two kernel launches: the first launch
637   // initializes the output array to the given initial value, and the second
638   // accumulates the "source" matrix into the selected elements of the output
639   // array. The first launch is already
640   // implemented by the initializer thunk generated earlier, so this function
641   // only needs to take care of the select-and-scatter part.
642   //
643   // Pseudo code for select-and-scatter:
644   //
645   // for (coordinates S in the source):  # This loop is parallel.
646   //   initialized_flag = false
647   //   for (coordinates W in the window):
648   //     I = S * stride + W - pad_low
649   //     if I within bounds of operand:
650   //       if !(initialized_flag and select(selected_value, operand(I))):
651   //         selected_value = operand(I)
652   //         selected_index = I
653   //         initialized_flag = true
654   //   output(selected_index) = scatter(output(selected_index), source(S))
655   auto loop_body_emitter = [=](const IrArray::Index& source_index) -> Status {
656     // Allocate space to keep the currently selected value, its index, and a
657     // boolean flag indicating whether the value has been initialized. The
658     // initialized_flag starts out false.
659     llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
660         llvm_ir::PrimitiveTypeToIrType(operand_element_type,
661                                        ir_emitter_context_->llvm_module()),
662         "selected_value_address", &b_);
663     llvm::Value* selected_index_address =
664         llvm_ir::EmitAllocaAtFunctionEntryWithCount(
665             index_type, index_typed_constant(rank), "selected_index_address",
666             &b_);
667     llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
668         b_.getInt1Ty(), "initialized_flag_address", &b_);
669     Store(b_.getInt1(false), initialized_flag_address);
670 
671     // Create the inner loop to iterate over the window.
672     llvm_ir::ForLoopNest window_loops(IrName(select_and_scatter, "inner"), &b_,
673                                       index_type);
674     DimensionVector window_size;
675     for (const auto& dim : window.dimensions()) {
676       window_size.push_back(dim.size());
677       CHECK_GT(dim.size(), 0);
678     }
679     const IrArray::Index window_index = window_loops.AddLoopsForShape(
680         ShapeUtil::MakeShape(operand_element_type, window_size), "window");
681     llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
682                                    &b_);
683 
684     // Compute the operand index to visit and evaluate whether that index is
685     // within the operand bounds. The unsigned comparison implicitly also
686     // checks that the operand index is >= 0.
687     std::vector<llvm::Value*> operand_multi_index(source_index.size());
688     llvm::Value* in_bounds_condition = b_.getInt1(true);
689     for (int64 i = 0; i < rank; ++i) {
690       llvm::Value* strided_index = NSWMul(
691           source_index[i], index_typed_constant(window.dimensions(i).stride()));
692       operand_multi_index[i] =
693           NSWSub(NSWAdd(strided_index, window_index[i]),
694                  index_typed_constant(window.dimensions(i).padding_low()));
695       llvm::Value* index_condition = ICmpULT(
696           operand_multi_index[i],
697           index_typed_constant(ShapeUtil::GetDimension(operand->shape(), i)));
698       in_bounds_condition = And(in_bounds_condition, index_condition);
699     }
700     CHECK(in_bounds_condition != nullptr);
701 
702     // Only need to do something if the operand index is within the bounds.
703     // First check if the initialized_flag is set.
704     llvm_ir::LlvmIfData if_in_bounds =
705         llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
706     llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
707     llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
708         Load(initialized_flag_address), "initialized", &b_);
709 
710     // If the initialized_flag is false, initialize the selected value and index
711     // with the operand element currently being visited.
712     llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &b_);
713     const auto save_operand_index = [&](const IrArray::Index& operand_index) {
714       for (int64 i = 0; i < rank; ++i) {
715         llvm::Value* selected_index_address_slot =
716             InBoundsGEP(selected_index_address, {b_.getInt32(i)});
717         Store(operand_index[i], selected_index_address_slot);
718       }
719     };
720     IrArray operand_array = GetIrArray(*operand, *select_and_scatter);
721     IrArray::Index operand_index(operand_multi_index, operand->shape(),
722                                  index_type);
723     llvm::Value* operand_data =
724         operand_array.EmitReadArrayElement(operand_index, &b_);
725     Store(operand_data, selected_value_address);
726     save_operand_index(operand_index);
727     Store(b_.getInt1(true), initialized_flag_address);
728 
729     // If the initialized_flag is true, call the `select` function to
730     // potentially update the selected value and index with the operand
731     // element currently being visited.
732     llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &b_);
733     llvm::Value* operand_address =
734         operand_array.EmitArrayElementAddress(operand_index, &b_);
735     llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
736         llvm_ir::PrimitiveTypeToIrType(PRED,
737                                        ir_emitter_context_->llvm_module()),
738         "select_return_buffer", &b_);
739     TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
740         *select_and_scatter->select(),
741         {selected_value_address, operand_address}, select_return_buffer));
742     llvm::Value* result = Load(select_return_buffer);
743 
744     // If the 'select' function returns false, update the selected value and the
745     // index to the operand element currently being visited.
746     llvm::Value* cond = ICmpNE(
747         result,
748         llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(
749                                    PRED, ir_emitter_context_->llvm_module()),
750                                0),
751         "boolean_predicate");
752     llvm_ir::LlvmIfData if_select_lhs =
753         llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
754     llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
755     Store(Load(operand_address), selected_value_address);
756     save_operand_index(operand_index);
757 
758     // After iterating over the window elements, scatter the source element to
759     // the selected index of the output. The value we store at the output
760     // location is computed by calling the `scatter` function with the source
761     // value and the current output value.
762     llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
763                                    &b_);
764     std::vector<llvm::Value*> selected_multi_index;
765     for (int64 i = 0; i < rank; ++i) {
766       llvm::Value* selected_index_address_slot =
767           InBoundsGEP(selected_index_address, {b_.getInt32(i)});
768       selected_multi_index.push_back(Load(selected_index_address_slot));
769     }
770     llvm::Value* source_value_address =
771         GetIrArray(*source, *select_and_scatter)
772             .EmitArrayElementAddress(source_index, &b_);
773     IrArray::Index selected_index(selected_multi_index,
774                                   select_and_scatter->shape(),
775                                   operand_index.GetType());
776     llvm::Value* output_value_address =
777         GetIrArray(*select_and_scatter, *select_and_scatter)
778             .EmitArrayElementAddress(selected_index, &b_);
779     return EmitAtomicOperationForNestedComputation(
780         *select_and_scatter->scatter(), output_value_address,
781         source_value_address);
782   };
783 
784   UpdateLaunchDimensions(
785       launch_dimensions,
786       // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk
787       // consisting of two thunks, an initializer KernelThunk that initializes
788       // the output and another KernelThunk that accumulates the scattered
789       // elements.
790       select_and_scatter_thunk->thunks().back().get(),
791       ir_emitter_context_->llvm_module());
792   AddThunkToThunkSequence(std::move(select_and_scatter_thunk));
793   return ParallelLoopEmitter(loop_body_emitter, source->shape(),
794                              launch_dimensions, &b_)
795       .EmitLoop(IrName(select_and_scatter), index_type);
796 }
797 
798 Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) {
799   HloComputation* condition = xla_while->while_condition();
800   TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) &&
801                condition->root_instruction()->shape().element_type() == PRED)
802       << "While condition computation must return bool";
803   // Build ForThunk for conformant while loops, otherwise build WhileThunk.
804   auto config = xla_while->backend_config<WhileLoopBackendConfig>();
805   if (config.ok() && config.ValueOrDie().has_known_trip_count()) {
806     AddThunkToThunkSequence(
807         BuildForThunk(xla_while, config.ValueOrDie().known_trip_count().n()));
808   } else {
809     AddThunkToThunkSequence(BuildWhileThunk(xla_while));
810   }
811   return Status::OK();
812 }
813 
814 Status IrEmitterUnnested::HandleRng(HloInstruction* rng) {
815   return Unimplemented("Rng should be expanded for GPU.");
816 }
817 
818 Status IrEmitterUnnested::HandleRngGetAndUpdateState(
819     HloInstruction* rng_state) {
820   // Emit a kernel to increment the global state for Philox RNG algorithm.
821   AddThunkToThunkSequence(
822       BuildKernelThunk(rng_state, /*implements_whole_instruction=*/true));
823 
824   llvm::Value* old_state = llvm_ir::RngGetAndUpdateState(
825       Cast<HloRngGetAndUpdateStateInstruction>(rng_state)->delta(), module_,
826       &b_);
827 
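  // Write the pre-update state into the (single-element) output buffer; the
  // output pointer is bitcast below so its pointee type matches the integer
  // type of the RNG state value being stored.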
828   llvm::Value* output_address =
829       GetIrArray(*rng_state, *rng_state)
830           .EmitArrayElementAddress(
831               llvm_ir::IrArray::Index(
832                   /*linear=*/b_.getInt64(0), rng_state->shape(), &b_),
833               &b_, "rng_state_address");
834   output_address = BitCast(
835       output_address, llvm::PointerType::get(
836                           old_state->getType(),
837                           output_address->getType()->getPointerAddressSpace()));
838   Store(old_state, output_address);
839 
840   return Status::OK();
841 }
842 
843 Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
844   const HloInstruction* operand = scatter->operand(0);
845   const HloInstruction* scatter_indices = scatter->operand(1);
846   const HloInstruction* updates = scatter->operand(2);
847   std::vector<std::unique_ptr<Thunk>> thunks;
848 
849   // Copy the operand into the output if it's not the same buffer already.
850   auto operand_buffer = GetAllocationSlice(*operand);
851   auto destination_buffer = GetAllocationSlice(*scatter);
852   if (operand_buffer != destination_buffer) {
853     thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
854         /*source_address=*/operand_buffer,
855         /*destination_buffer=*/destination_buffer,
856         /*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()),
857         /*hlo_instruction=*/nullptr));
858   }
859 
860   thunks.push_back(
861       BuildKernelThunk(scatter,
862                        /*implements_whole_instruction=*/thunks.empty()));
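  // The kernel thunk implements the whole instruction only when no copy thunk
  // was added above; otherwise the copy and the kernel together implement it.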
863 
864   TF_RETURN_IF_ERROR(EmitScatter(
865       thunks.back().get(), scatter,
866       /*scatter_indices_gen=*/
867       [=](const IrArray::Index& index) {
868         return GetIrArray(*scatter_indices, *scatter)
869             .EmitReadArrayElement(index, &b_, "scatter_index");
870       },
871       /*updates_gen=*/
872       [=](const IrArray::Index& index) {
873         return GetIrArray(*updates, *scatter)
874             .EmitReadArrayElement(index, &b_, "update");
875       }));
876 
877   // Elide the sequential thunk if there's no copy.
878   if (thunks.size() == 1) {
879     AddThunkToThunkSequence(std::move(thunks[0]));
880   } else {
881     AddThunkToThunkSequence(
882         absl::make_unique<SequentialThunk>(std::move(thunks), scatter));
883   }
884 
885   return Status::OK();
886 }
887 
888 Status IrEmitterUnnested::EmitScatter(
889     Thunk* thunk, HloInstruction* scatter,
890     const llvm_ir::ElementGenerator& scatter_indices_gen,
891     const llvm_ir::ElementGenerator& updates_gen) {
892   const HloInstruction* operand = scatter->operand(0);
893   const HloInstruction* scatter_indices = scatter->operand(1);
894   const HloInstruction* updates = scatter->operand(2);
895   const ScatterDimensionNumbers& dim_numbers =
896       scatter->scatter_dimension_numbers();
897   CHECK(ShapeUtil::Equal(scatter->shape(), operand->shape()));
898 
899   auto loop_body_emitter = [&](const IrArray::Index& index) -> Status {
900     std::vector<llvm::Value*> raw_window_multidim;
901     std::vector<llvm::Value*> input_scatter_multidim;
902     std::vector<int64> raw_window_bounds;
903 
904     // Partition the index into window indices and scatter indices.
905     for (int64 i = 0, e = index.size(); i != e; ++i) {
906       // For window indices, also remember the window size; this comes in handy
907       // later.
908       if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) {
909         raw_window_multidim.push_back(index[i]);
910         raw_window_bounds.push_back(updates->shape().dimensions(i));
911       } else {
912         input_scatter_multidim.push_back(index[i]);
913       }
914     }
915     DCHECK_EQ(raw_window_multidim.size(),
916               dim_numbers.update_window_dims_size());
917 
918     // Apply inserted_window_dims to the window dimensions.
919     int64 raw_window_multidim_idx = 0;
920     std::vector<llvm::Value*> input_window_multidim;
921     std::vector<int64> input_window_bounds;
922     for (int64 i = 0, e = operand->shape().rank(); i != e; ++i) {
923       if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) {
924         input_window_bounds.push_back(1);  // Trivial dimension.
925         input_window_multidim.push_back(index.GetConstantWithIndexType(0));
926       } else {
927         input_window_bounds.push_back(
928             raw_window_bounds[raw_window_multidim_idx]);
929         input_window_multidim.push_back(
930             raw_window_multidim[raw_window_multidim_idx]);
931         ++raw_window_multidim_idx;
932       }
933     }
934     DCHECK_EQ(input_window_multidim.size(), operand->shape().rank());
935 
936     // Insert a 1 dimension at the end if index_vector_dim requests one.
937     Shape scatter_indices_shape = scatter_indices->shape();
938     if (dim_numbers.index_vector_dim() == scatter_indices_shape.rank()) {
939       scatter_indices_shape.add_dimensions(1);
940       scatter_indices_shape.mutable_layout()->add_minor_to_major(
941           dim_numbers.index_vector_dim());
942     }
943 
944     // Now load the indices corresponding to the current window from
945     // scatter_indices.
946     std::vector<llvm::Value*> raw_scatter_index_multidim =
947         input_scatter_multidim;
948     raw_scatter_index_multidim.insert(
949         raw_scatter_index_multidim.begin() + dim_numbers.index_vector_dim(),
950         nullptr);
951     llvm::Value* is_in_bounds = b_.getTrue();
952     for (int64 i = 0, e = dim_numbers.scatter_dims_to_operand_dims_size();
953          i != e; ++i) {
954       // Our index is stored along index_vector_dim; insert that into the lookup
955       // index into scatter_indices.
956       raw_scatter_index_multidim[dim_numbers.index_vector_dim()] =
957           index.GetConstantWithIndexType(i);
958       llvm_ir::IrArray::Index raw_scatter_index_index(
959           raw_scatter_index_multidim, scatter_indices_shape, index.GetType());
960 
961       int64 operand_dim = dim_numbers.scatter_dims_to_operand_dims(i);
962       TF_ASSIGN_OR_RETURN(
963           llvm::Value* const loaded_scatter_index,
964           scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape(
965               scatter_indices_shape, scatter_indices->shape(), &b_)));
966       // And add the index to our window index. This yields the output index.
967       llvm::Value* casted_scatter_index =
968           IntCast(loaded_scatter_index, index.GetType(),
969                   /*isSigned=*/true);
970       llvm::Value* dim_offset =
971           Add(input_window_multidim[operand_dim], casted_scatter_index);
972       input_window_multidim[operand_dim] = dim_offset;
973 
974       // Also do the bounds check now.
975       int64 max_index = operand->shape().dimensions(operand_dim) -
976                         input_window_bounds[operand_dim] + 1;
977       // is_in_bounds = index >= 0 && index < dim_size-window_size+1
978       //   --> index u< dim_size-window_size+1
979       is_in_bounds =
980           And(is_in_bounds, ICmpULT(casted_scatter_index,
981                                     index.GetConstantWithIndexType(max_index)));
982     }
983 
984     llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse(
985         is_in_bounds, "scatter.in_bounds", &b_, /*emit_else=*/false);
986     llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_);
987     // All done; now just read the input value at the calculated window index
988     // and do an atomic store to the calculated location in the output.
989     HloInstruction* output_hlo =
990         scatter->IsFused() ? scatter->parent()->FusionInstruction() : scatter;
991     llvm_ir::IrArray::Index input_window_index(
992         input_window_multidim, output_hlo->shape(), index.GetType());
993     llvm::Value* output_address =
994         GetIrArray(*output_hlo, *output_hlo)
995             .EmitArrayElementAddress(input_window_index, &b_);
996     llvm::Value* input_address = Alloca(llvm_ir::PrimitiveTypeToIrType(
997         updates->shape().element_type(), module_));
998     TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value, updates_gen(index));
999     Store(input_ir_value, input_address);
1000 
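    // If the scatter indices are not guaranteed unique, updates to the same
    // output element may race, so combine them with an atomic operation;
    // otherwise a plain call to the update computation suffices.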
1001     if (!scatter->unique_indices()) {
1002       return EmitAtomicOperationForNestedComputation(
1003           *scatter->to_apply(), output_address, input_address);
1004     } else {
1005       return EmitCallToNestedComputation(*scatter->to_apply(),
1006                                          {output_address, input_address},
1007                                          output_address);
1008     }
1009   };
1010 
1011   // Launch a kernel that reads every element in the updates tensor. We could
1012   // also do one kernel per window instead if bounds checks turn out to be a
1013   // bottleneck.
1014   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
1015       updates->shape(), ir_emitter_context_->device_description());
1016   UpdateLaunchDimensions(launch_dimensions, thunk,
1017                          ir_emitter_context_->llvm_module());
1018 
1019   return ParallelLoopEmitter(loop_body_emitter, updates->shape(),
1020                              launch_dimensions, &b_)
1021       .EmitLoop(IrName(scatter),
1022                 GetIndexTypeForKernel(scatter, launch_dimensions.launch_bound(),
1023                                       &b_));
1024 }
1025 
1026 Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
1027   return IrEmitter::HandleSelect(select);
1028 }
1029 
1030 Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
1031   std::vector<std::unique_ptr<Thunk>> thunks;
1032   Shape keys_shape = sort->operand(0)->shape();
1033   int64 dimension_to_sort = sort->dimensions(0);
1034   for (int64 i = 0; i < sort->operand_count(); ++i) {
1035     ShapeIndex shape_index =
1036         sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({});
1037     // We assume that the layout of all involved operands and outputs is the
1038     // same.
1039     TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(keys_shape,
1040                                                   sort->operand(i)->shape()));
1041     TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
1042         keys_shape, ShapeUtil::GetSubshape(sort->shape(), shape_index)));
1043 
1044     // If possible, we share buffers. If that is not possible, we need to copy
1045     // the values, because the emitter does the sorting in-place.
1046     auto destination_buffer = GetAllocationSlice(*sort, shape_index);
1047     auto source_address = GetAllocationSlice(*sort->operand(i));
1048     if (destination_buffer != source_address) {
1049       // TODO(b/26783907): Figure out why we never seem to share buffers for
1050       // key/value sort.
1051       thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
1052           /*source_address=*/source_address,
1053           /*destination_buffer=*/destination_buffer,
1054           /*mem_size=*/ShapeUtil::ByteSizeOf(sort->operand(i)->shape()),
1055           nullptr));
1056     }
1057   }
1058 
1059   uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
1060   int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound);
1061   CHECK_GE(1ULL << num_stages, dimension_to_sort_bound);
1062   CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound);
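  // num_stages is the number of bitonic sorting stages: the smallest k with
  // 2^k >= dimension_to_sort_bound, as the two CHECKs above verify.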
1063 
1064   // Naive C++ code for the outer loops:
1065   //
1066   // for (int64 stage = 0; stage < Log2Ceiling(dimension_to_sort_bound);
1067   //     ++stage) {
1068   //   int64 first_xor_mask = (1LL << (stage + 1)) - 1;
1069   //   SortInPlace(first_xor_mask);
1070   //   for (int64 mask = stage - 1; mask >= 0; --mask) {
1071   //     int64 later_xor_mask = 1LL << mask;
1072   //     SortInPlace(later_xor_mask);
1073   //   }
1074   // }
1075   //
1076   // This follows the alternative representation of the algorithm described on
1077   // Wikipedia: https://en.wikipedia.org/wiki/Bitonic_sorter
1078   //
1079   // Each mask specifies how to derive from one position in the array the
1080   // position with which it should be compared (we calculate the xor of the
1081   // position with the mask).
1082   // As an optimization, we can move the 'mask' loop to inside the
1083   // sorting/comparison loop if the comparisons happen within a small block of
1084   // the array. To make this work, we collect all consecutive masks that are
1085   // smaller than our chosen power of 2 tile size, and pass them to SortInPlace.
1086   // Each thread then processes one tile of data.
1087 
1088   const uint64 kTileSize = std::min(2048ULL, 1ULL << num_stages);
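  // The tile size is capped at 2048 elements and never exceeds the
  // power-of-two-rounded bound of the dimension being sorted.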
1089 
1090   // If we cannot combine several xor masks together, we don't use tiling, so we
1091   // calculate the standard launch dimensions for the shape. However, we only
1092   // need to iterate through ~half of the dimension to sort (rounded up to the
1093   // next highest power of 2), because each iteration compares one pair of
1094   // elements.
1095   Shape standard_iteration_shape = keys_shape;
1096   uint64 standard_num_iterations_in_sort_dim = 1ULL << (num_stages - 1);
1097   standard_iteration_shape.set_dimensions(dimension_to_sort,
1098                                           standard_num_iterations_in_sort_dim);
1099   LaunchDimensions standard_launch_dimensions = CalculateLaunchDimensions(
1100       standard_iteration_shape, ir_emitter_context_->device_description());
1101 
1102   // Calculate the launch dimensions for the case where we use tiling. We split
1103   // the dimension that should be sorted into tiles of size 'kTileSize'. This
1104   // means we first need to round 'dimension_to_sort_bound' up to be a multiple
1105   // of the tile size.
1106   int64 rounded_bound = RoundUpToNearest(dimension_to_sort_bound, kTileSize);
1107   Shape iteration_shape = keys_shape;
1108 
1109   // We iterate through the element pairs that should be compared.
1110   uint64 num_iterations_in_sort_dim = rounded_bound / 2;
1111   iteration_shape.set_dimensions(dimension_to_sort, num_iterations_in_sort_dim);
1112   uint64 num_iterations = ShapeUtil::ElementsIn(iteration_shape);
1113 
1114   // For correctness reasons we need exactly 'kTileSize' / 2 threads per
1115   // block. Each thread is responsible for copying exactly two adjacent elements
1116   // into shared memory, and then does a comparison of two possibly different
1117   // elements taken from shared memory.
1118   const uint64 kThreadsPerBlock = kTileSize / 2;
1119 
1120   // Check whether we should use any tiling. We might not be able to use it if
1121   // we do not have enough threads or enough shared memory. Tiling also does
1122   // not give a speedup if the tile size is < 128.
1123   int64 total_shared_memory_needed = 0;
1124   for (int64 i = 0; i < sort->operand_count(); ++i) {
1125     total_shared_memory_needed +=
1126         kTileSize * ShapeUtil::ByteSizeOfPrimitiveType(
1127                         sort->operand(i)->shape().element_type());
1128   }
1129   bool no_tiling =
1130       kTileSize < 128 ||
1131       kThreadsPerBlock >
1132           ir_emitter_context_->device_description().threads_per_block_limit() ||
1133       total_shared_memory_needed >
1134           ir_emitter_context_->device_description().shared_memory_per_block();
1135 
1136   uint64 num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock);
1137   LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock);
1138 
1139   auto emit_kernel = [&](absl::Span<const int64> xor_masks) {
1140     thunks.push_back(
1141         BuildKernelThunk(sort, /*implements_whole_instruction=*/false));
1142     LaunchDimensions launch_dimensions = xor_masks.size() > 1
1143                                              ? tiled_launch_dimensions
1144                                              : standard_launch_dimensions;
1145     UpdateLaunchDimensions(launch_dimensions, thunks.back().get(),
1146                            ir_emitter_context_->llvm_module());
1147     std::vector<IrArray> values_arrays;
1148     values_arrays.reserve(sort->operand_count());
1149     for (int64 i = 0; i < sort->operand_count(); ++i) {
1150       ShapeIndex shape_index =
1151           sort->operand_count() > 1 ? ShapeIndex({i}) : ShapeIndex({});
1152       values_arrays.push_back(GetIrArray(*sort, *sort, shape_index));
1153     }
1154     return llvm_ir::EmitSortInPlace(
1155         dimension_to_sort, values_arrays, IrName(sort), xor_masks, &b_,
1156         launch_dimensions,
1157         xor_masks.size() > 1 ? num_iterations_in_sort_dim
1158                              : standard_num_iterations_in_sort_dim,
1159         kTileSize,
1160         [&](absl::Span<llvm::Value* const> operands, llvm::Value* output) {
1161           return EmitCallToNestedComputation(*sort->to_apply(), operands,
1162                                              output);
1163         });
1164   };
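  // The loops below enumerate the xor masks of a bitonic sorting network.
  // Editorial trace: for num_stages = 3 the emitted mask sequence is
  //   stage 0: 1
  //   stage 1: 3, 1
  //   stage 2: 7, 2, 1
  // where mask == stage yields (1 << (stage + 1)) - 1 and every other mask
  // yields 1 << mask. Consecutive masks smaller than kTileSize are batched
  // into 'xor_masks' so they can be handled by a single tiled kernel launch.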
1165   std::vector<int64> xor_masks;
1166   for (int64 stage = 0; stage < num_stages; ++stage) {
1167     for (int64 mask = stage; mask >= 0; --mask) {
1168       int64 xor_mask;
1169       if (mask == stage) {
1170         xor_mask = (1LL << (stage + 1)) - 1;
1171       } else {
1172         xor_mask = 1LL << mask;
1173       }
1174       if (xor_mask >= kTileSize || no_tiling) {
1175         if (!xor_masks.empty()) {
1176           TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
1177           xor_masks.clear();
1178         }
1179         TF_RETURN_IF_ERROR(emit_kernel({xor_mask}));
1180       } else {
1181         xor_masks.push_back(xor_mask);
1182       }
1183     }
1184   }
1185   if (!xor_masks.empty()) {
1186     TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
1187   }
1188 
1189   AddThunkToThunkSequence(
1190       absl::make_unique<SequentialThunk>(std::move(thunks), sort));
1191   return Status::OK();
1192 }
1193 
1194 Status IrEmitterUnnested::HandleTupleSelect(HloInstruction* tuple_select) {
1195   AddThunkToThunkSequence(
1196       BuildKernelThunk(tuple_select, /*implements_whole_instruction=*/true));
1197   return IrEmitter::HandleTupleSelect(tuple_select);
1198 }
1199 
1200 Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) {
1201   AddThunkToThunkSequence(
1202       absl::make_unique<ReplicaIdThunk>(GetAllocationSlice(*hlo), hlo));
1203   return Status::OK();
1204 }
1205 
1206 Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {
1207   AddThunkToThunkSequence(absl::make_unique<CollectivePermuteThunk>(
1208       GetAllocationSlice(*hlo->operand(0)), GetAllocationSlice(*hlo), hlo));
1209   return Status::OK();
1210 }
1211 
1212 namespace {
1213 
1214 
1215 }  // namespace
1216 
1217 Status IrEmitterUnnested::HandleAllReduce(HloInstruction* crs) {
1218   VLOG(2) << "AllReduce; replica count: " << hlo_module_config_.replica_count()
1219           << "; operand count: " << crs->operand_count()
1220           << "; NCCL is enabled: " << NcclAllReduceThunk::NcclIsEnabled();
1221 
1222   // Note the replica_count == 1 case is handled via device-to-device copy
1223   // below.
1224   bool should_use_nccl_thunk = hlo_module_config_.replica_count() > 1 &&
1225                                NcclAllReduceThunk::CanImplement(crs);
1226 
1227   if (should_use_nccl_thunk) {
1228     CHECK(crs->operand(0)->shape().IsArray())
1229         << "Operands to all-reduce must be arrays: " << crs->ToString();
1230     AddThunkToThunkSequence(absl::make_unique<NcclAllReduceThunk>(
1231         /*replica_count=*/hlo_module_config_.replica_count(),
1232         /*elements=*/ShapeUtil::ElementsIn(crs->operand(0)->shape()),
1233         /*source_address=*/GetAllocationSlice(*crs->operand(0)),
1234         /*destination_buffer=*/GetAllocationSlice(*crs), crs));
1235     return Status::OK();
1236   }
1237 
1238   if (hlo_module_config_.replica_count() != 1) {
1239     // TODO(b/33011107): Support more AllReduce configurations on GPU.
1240     string message = absl::StrFormat(
1241         "Requested AllReduce not implemented on GPU; replica_count: %d; "
1242         "operand_count: %d; IsCrossReplicaAllReduce: %d; NCCL support: %d",
1243         hlo_module_config_.replica_count(), crs->operand_count(),
1244         crs->IsCrossReplicaAllReduce(), NcclAllReduceThunk::NcclIsEnabled());
1245     if (crs->operand_count() > 0) {
1246       absl::StrAppendFormat(
1247           &message, "; first operand array element-type: %s",
1248           PrimitiveType_Name(crs->operand(0)->shape().element_type()));
1249     }
1250     return Unimplemented("%s", message);
1251   }
1252 
1253   // CRS with one operand and one replica is simply the identity function.
1254   // Buffer assignment expects a copy, so that's what we do.
1255   //
1256   // TODO(b/80100934): We would like to eliminate one-replica CRS nodes entirely
1257   // in algebraic-simplifier, but currently on some platforms
1258   // HloModuleConfig::num_replicas changes between when the module is compiled
1259   // and when it's run.
1260   if (crs->operand_count() == 1) {
1261     CHECK(crs->operand(0)->shape().IsArray())
1262         << "Operands to all-reduce must be arrays: " << crs->ToString();
1263     AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
1264         /*source_address=*/GetAllocationSlice(*crs->operand(0)),
1265         /*destination_buffer=*/GetAllocationSlice(*crs),
1266         /*mem_size=*/ShapeUtil::ByteSizeOf(crs->shape()), crs));
1267     return Status::OK();
1268   }
1269 
1270   // One-replica CRS with multiple operands produces a tuple of the inputs.
1271   // Again, buffer assignment expects us to copy each.
1272   std::vector<std::unique_ptr<Thunk>> thunks;
1273   std::vector<BufferAllocation::Slice> tuple_element_buffers;
1274   for (int64 i = 0; i < crs->operand_count(); ++i) {
1275     tuple_element_buffers.push_back(ir_emitter_context_->buffer_assignment()
1276                                         .GetUniqueSlice(crs, {i})
1277                                         .ValueOrDie());
1278     thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
1279         /*source_address=*/GetAllocationSlice(*crs->operand(i)),
1280         /*destination_buffer=*/tuple_element_buffers.back(),
1281         /*mem_size=*/ShapeUtil::ByteSizeOf(crs->operand(i)->shape()), nullptr));
1282   }
1283 
1284   // Output a tuple of the buffers above.
1285   thunks.push_back(absl::make_unique<TupleThunk>(
1286       tuple_element_buffers, GetAllocationSlice(*crs), nullptr));
1287   AddThunkToThunkSequence(
1288       absl::make_unique<SequentialThunk>(std::move(thunks), crs));
1289   return Status::OK();
1290 }
1291 
1292 Status IrEmitterUnnested::HandleInfeed(HloInstruction* xla_infeed) {
1293   return ThunkEmitter(this).HandleInfeed(xla_infeed);
1294 }
1295 
1296 Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) {
1297   return ThunkEmitter(this).HandleOutfeed(outfeed);
1298 }
1299 
1300 Status IrEmitterUnnested::HandleAfterAll(HloInstruction* after_all) {
1301   return Status::OK();
1302 }
1303 
1304 // Figures out how to access the buffers for all subshapes of hlo's operands and
1305 // for hlo itself (i.e. all the buffers produced by HLO).
1306 //
1307 // Returns a map keyed on the pair {HloInstruction, ShapeIndex}.  The value for
1308 // this key is a pair {Slice, ShapeIndex}, where the slice tells you the root
1309 // buffer to look in, and the ShapeIndex describes how to dereference starting
1310 // at that buffer to get to the buffer in question.
1311 //
1312 // For example, if {hlo, {1}} is mapped to {slice, {3, 4}}, then the buffer for
1313 // hlo at ShapeIndex {1} (i.e. the buffer for the second tuple element of hlo)
1314 // is found at slice[3][4].  That is, slice is a void***, which we dereference
1315 // twice -- first at index 3, and then at index 4 -- to get the address of our
1316 // buffer.
1317 //
1318 // This function conservatively assumes that we'll touch all sub-buffers of
1319 // every operand and of the output.
1320 static std::map<std::pair<const HloInstruction*, ShapeIndex>,
1321                 std::pair<BufferAllocation::Slice, ShapeIndex>>
1322 GetHloBufferSlices(const HloInstruction* hlo,
1323                    const BufferAssignment& buffer_assn) {
1324   std::map<std::pair<const HloInstruction*, ShapeIndex>,
1325            std::pair<BufferAllocation::Slice, ShapeIndex>>
1326       slices;
1327 
1328   // Tries to find a slice plus an array of indices i1, ..., iN such that the
1329   // sub-buffer for instr at index can be found at slice[i1]...[iN].
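  // Editorial example: if instr is gte(gte(param, 3), 4) and only param has
  // an assigned buffer, this walks up through both get-tuple-elements and
  // returns param's slice together with the indices {3, 4}; the caller then
  // dereferences the slice at [3][4] to reach the sub-buffer.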
1330   auto find_slice_for = [&](const HloInstruction* instr,
1331                             const ShapeIndex& index)
1332       -> optional<std::pair<BufferAllocation::Slice, ShapeIndex>> {
1333     // Simple, common case: Is the buffer for instr known at runtime?  If so,
1334     // we're done.
1335     auto slice = buffer_assn.GetUniqueSlice(instr, index);
1336     if (slice.ok()) {
1337       return {{slice.ValueOrDie(), ShapeIndex()}};
1338     }
1339 
1340     // If that didn't work, walk up any bitcasts that we might see.  These must
1341     // appear before any GTE instructions, because it's illegal to bitcast to a
1342     // tuple type.
1343     const HloInstruction* parent = instr;
1344     while (parent->opcode() == HloOpcode::kBitcast) {
1345       parent = parent->operand(0);
1346 
1347       auto slice = buffer_assn.GetUniqueSlice(parent, {});
1348       if (slice.ok()) {
1349         return {{slice.ValueOrDie(), ShapeIndex()}};
1350       }
1351     }
1352 
1353     // Check whether instr is a GTE instruction.  If it is, see if we can get a
1354     // buffer for its parent, and continue walking up parents until we find a
1355     // defined buffer or we hit something that's not a GTE.
1356     ShapeIndex gte_indices;
1357     while (parent->opcode() == HloOpcode::kGetTupleElement) {
1358       gte_indices.push_front(parent->tuple_index());
1359       parent = parent->operand(0);
1360 
1361       auto slice = buffer_assn.GetUniqueSlice(parent, {});
1362       if (slice.ok()) {
1363         return {{slice.ValueOrDie(), gte_indices}};
1364       }
1365     }
1366 
1367     // Finally, if we don't know the buffer for instr at index, see if we know
1368     // the buffer for instr at index without its last element.  If so, we can
1369     // dynamically find the buffer for instr by dereferencing a pointer in that
1370     // buffer.  Continue looking this way until we run out of elements in
1371     // 'index'.
1372     //
1373     // We can almost always get a buffer without resorting to this.  The only
1374     // exception is for cases where the relevant sub-buffer is truly unknowable,
1375     // for example the sub-buffer of a tuple-shaped select.
1376     ShapeIndex new_index = index;
1377     while (!new_index.empty()) {
1378       gte_indices.push_front(new_index.back());
1379       new_index.pop_back();
1380       auto slice = buffer_assn.GetUniqueSlice(instr, new_index);
1381       if (slice.ok()) {
1382         return {{slice.ValueOrDie(), gte_indices}};
1383       }
1384     }
1385 
1386     return nullopt;
1387   };
1388 
1389   // Adds entries for all subshapes of instr to `slices`.
1390   auto add_slices_for = [&](const HloInstruction* instr) {
1391     ShapeUtil::ForEachSubshape(
1392         instr->shape(), [&](const Shape& /*shape*/, const ShapeIndex& index) {
1393           if (slices.count({instr, index})) {
1394             // HLOs can have duplicate operands; don't bother redoing work.
1395             return;
1396           }
1397           auto maybe_slice = find_slice_for(instr, index);
1398           if (maybe_slice.has_value()) {
1399             slices[{instr, index}] = *maybe_slice;
1400           } else {
1401             VLOG(1) << "Couldn't find buffer for " << instr->ToString()
1402                     << " at index " << index.ToString();
1403           }
1404         });
1405   };
1406 
1407   add_slices_for(hlo);
1408   for (const HloInstruction* operand : hlo->operands()) {
1409     // Conservatively assume we'll need the buffers for all subshapes of the
1410     // operand.
1411     add_slices_for(operand);
1412   }
1413 
1414   return slices;
1415 }
1416 
1417 std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
1418     const HloInstruction* inst, bool implements_whole_instruction,
1419     int unroll_factor) {
1420   const BufferAssignment& buffer_assn =
1421       ir_emitter_context_->buffer_assignment();
1422 
1423   std::map<std::pair<const HloInstruction*, ShapeIndex>,
1424            std::pair<BufferAllocation::Slice, ShapeIndex>>
1425       hlo_slices = GetHloBufferSlices(inst, buffer_assn);
1426 
1427   // Figure out which buffer allocations need to be passed as arguments to our
1428   // kernel.  This is simply all of the allocations referenced in hlo_slices,
1429   // plus the XLA temp buffer (if we have it).  We always include the temp
1430   // buffer because even if the kernel itself doesn't use it, a nested
1431   // subcomputation within the kernel (e.g. a kMap's computation) might.
1432   std::unordered_set<const BufferAllocation*> buffers_needed;
1433   for (const auto& kv : hlo_slices) {
1434     buffers_needed.insert(kv.second.first.allocation());
1435   }
1436   absl::optional<const BufferAllocation*> temp_buffer;
1437   for (const BufferAllocation& alloc : buffer_assn.Allocations()) {
1438     if (alloc.IsPreallocatedTempBuffer()) {
1439       if (!temp_buffer.has_value()) {
1440         temp_buffer = &alloc;
1441       } else {
1442         LOG(FATAL) << "Multiple temp buffers found, but only one is allowed!";
1443       }
1444     }
1445   }
1446   if (temp_buffer.has_value()) {
1447     buffers_needed.insert(*temp_buffer);
1448   }
1449 
1450   // We'll pass a pointer to each of the elements of `non_constant_buffers` to
1451   // our kernel, in this order.
1452   std::vector<const BufferAllocation*> non_constant_buffers;
1453   absl::c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers),
1454                   [](const BufferAllocation* allocation) {
1455                     return !allocation->is_constant();
1456                   });
1457 
1458   absl::c_sort(non_constant_buffers,
1459                [](const BufferAllocation* a, const BufferAllocation* b) {
1460                  return a->index() < b->index();
1461                });
1462 
1463   llvm::Function* kernel = BuildKernelPrototype(*inst, non_constant_buffers);
1464 
1465   // Build a map from a BufferAllocation to the corresponding argument in our
1466   // kernel.
1467   std::unordered_map<const BufferAllocation*, llvm::Value*> kernel_args;
1468   {
1469     auto arg_it = kernel->arg_begin();
1470     auto buffers_it = non_constant_buffers.begin();
1471     for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) {
1472       kernel_args[*buffers_it] = arg_it;
1473     }
1474   }
1475 
1476   // For each buffer our kernel might want to touch, bind it to a value derived
1477   // from our kernel args.
1478   for (const auto& kv : hlo_slices) {
1479     const HloInstruction* instr = kv.first.first;
1480     const ShapeIndex& index = kv.first.second;
1481     const BufferAllocation::Slice& slice = kv.second.first;
1482     const ShapeIndex& gte_index = kv.second.second;
1483 
1484     VLOG(3) << "Buffer for " << instr->ToString() << " at " << index.ToString()
1485             << " is found in slice " << slice.ToString() << " at GTE index "
1486             << gte_index.ToString();
1487 
1488     llvm::Value* loc;
1489     if (slice.allocation()->is_constant()) {
1490       loc = ir_emitter_context_->llvm_module()->getGlobalVariable(
1491           llvm_ir::ConstantBufferAllocationToGlobalName(*slice.allocation()));
1492       CHECK_NE(loc, nullptr);
1493     } else {
1494       loc = InBoundsGEP(kernel_args.at(slice.allocation()),
1495                         {b_.getInt64(slice.offset())});
1496     }
1497 
1498     // If gte_index is nonempty, we have to dereference `loc` to get to the
1499     // value we're ultimately interested in.
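    // Editorial illustration: for gte_index = {3, 4} the loop below emits
    // loc = load(gep(bitcast(loc, i8**), 3)) followed by
    // loc = load(gep(bitcast(loc, i8**), 4)), i.e. the slice is treated as a
    // void*** that is dereferenced twice, matching the GetHloBufferSlices
    // comment above.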
1500     llvm::Type* int8_double_pointer =
1501         llvm::PointerType::get(b_.getInt8PtrTy(), /*AddressSpace=*/0);
1502     for (int64 idx : gte_index) {
1503       loc = b_.CreatePointerBitCastOrAddrSpaceCast(loc, int8_double_pointer);
1504       loc = Load(InBoundsGEP(loc, {b_.getInt64(idx)}));
1505     }
1506 
1507     bindings_.BindHloToIrValue(*instr, loc, index);
1508   }
1509 
1510   // Bind the temp buffer so that nested subcomputations can find it if they
1511   // need it.
1512   if (temp_buffer.has_value()) {
1513     bindings_.SetTempBufferBase(kernel_args.at(*temp_buffer));
1514   } else {
1515     bindings_.SetTempBufferBase(
1516         llvm::ConstantPointerNull::get(b_.getInt8PtrTy()));
1517   }
1518 
1519   return absl::make_unique<KernelThunk>(
1520       non_constant_buffers, std::string(kernel->getName()),
1521       implements_whole_instruction ? inst : nullptr, unroll_factor);
1522 }
1523 
1524 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
1525     HloInstruction* hlo, const ShapeIndex& index) {
1526   bool fused = HloOpcode::kFusion == hlo->opcode();
1527   HloInstruction* inst = fused ? hlo->fused_expression_root() : hlo;
1528   HloInstruction* init_value_operand = [&] {
1529     switch (inst->opcode()) {
1530       case HloOpcode::kSelectAndScatter:
1531         return inst->mutable_operand(2);
1532       case HloOpcode::kReduce:
1533         return inst->mutable_operand(1);
1534       case HloOpcode::kTuple:
1535         CHECK(hlo->IsMultiOutputFusion())
1536             << ": " << hlo->ToString() << " is not a multi-output fusion.";
1537         CHECK(inst->operand(index.back())->opcode() == HloOpcode::kReduce)
1538             << ": Found '" << inst->operand(index.back())->opcode() << "' in "
1539             << inst->ToString() << " but expected 'reduce'.";
1540         // For multi-output fusion look through the tuple.
1541         return inst->mutable_operand(index.back())->mutable_operand(1);
1542       default:
1543         LOG(FATAL) << "Opcode " << inst->opcode()
1544                    << " should not need an initializer.";
1545     }
1546   }();
1547 
1548   const HloInstruction* init_value = init_value_operand;
1549   if (fused && init_value->opcode() == HloOpcode::kParameter) {
1550     init_value = hlo->operand(init_value->parameter_number());
1551   }
1552 
1553   // Initializer thunks don't implement a whole instruction, and we want to
1554   // profile the whole instruction instead of the individual thunks it consists
1555   // of. Therefore we pass nullptr as the HloInstruction* to the thunks we
1556   // generate below.
1557   //
1558   // In the common case, the initializer is a constant.  In this case, emit a
1559   // device-memset call if we can.  Currently StreamExecutor only supports
1560   // zeroing and 32-bit memsets.
1561   if (init_value->IsConstant()) {
1562     CHECK(ShapeUtil::IsScalar(init_value->shape()));
1563     int64 num_bytes = ShapeUtil::ByteSizeOfElements(init_value->shape());
1564     const auto& literal = init_value->literal();
1565 
1566     // Are all the bytes of this scalar equal to 0?  If so, we can create a
1567     // MemzeroThunk.
1568     absl::Span<const uint8> literal_bytes(
1569         reinterpret_cast<const uint8*>(literal.untyped_data()), num_bytes);
1570     if (absl::c_all_of(literal_bytes, [](uint8 byte) { return byte == 0; })) {
1571       return {absl::make_unique<MemzeroThunk>(GetAllocationSlice(*hlo, index),
1572                                               nullptr)};
1573     }
1574 
1575     // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
1576     // repeating the literal 4 or 2 times, so long as the destination buffer is
1577     // an even multiple of 32 bits long.
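    // Editorial example: for a u8 init value of 0xAB, pattern16 below becomes
    // 0xABAB and pattern32 becomes 0xABABABAB, so one 32-bit memset
    // reproduces the byte pattern across an output buffer whose size is a
    // multiple of 32 bits.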
1578     const Shape& output_shape = ShapeUtil::GetSubshape(hlo->shape(), index);
1579     if ((num_bytes == 1 || num_bytes == 2) &&
1580         ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) {
1581       uint16 pattern16;
1582       if (num_bytes == 1) {
1583         uint8 b = literal_bytes.front();
1584         pattern16 = uint16{b} | (uint16{b} << 8);
1585       } else {
1586         memcpy(&pattern16, literal_bytes.data(), sizeof(pattern16));
1587       }
1588       uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
1589       return {absl::make_unique<Memset32BitValueThunk>(
1590           pattern32, GetAllocationSlice(*hlo, index), nullptr)};
1591     }
1592 
1593     // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
1594     // memset so long as all 32-bit words of the scalar are equal to each other.
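    // Editorial example: a 64-bit scalar with bit pattern 0x0102030401020304
    // qualifies (both 32-bit words are equal) and becomes a 32-bit memset of
    // that word, whereas double 1.0 (0x3FF0000000000000) does not, since its
    // two words differ, and falls through to the generic initializer emitted
    // below.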
1595     if (num_bytes >= 4 && num_bytes % 4 == 0 &&
1596         memcmp(literal_bytes.data(), literal_bytes.data() + 4,
1597                literal_bytes.size() - 4) == 0) {
1598       uint32 word;
1599       memcpy(&word, literal_bytes.data(), sizeof(word));
1600       return {absl::make_unique<Memset32BitValueThunk>(
1601           word, GetAllocationSlice(*hlo, index), nullptr)};
1602     }
1603   }
1604 
1605   // Otherwise fall back to our slow initializer code.
1606   std::unique_ptr<KernelThunk> kernel_thunk =
1607       BuildKernelThunk(hlo, /*implements_whole_instruction=*/false);
1608   LaunchDimensions launch_dimensions =
1609       CalculateLaunchDimensions(ShapeUtil::GetSubshape(hlo->shape(), index),
1610                                 ir_emitter_context_->device_description());
1611   UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
1612                          ir_emitter_context_->llvm_module());
1613 
1614   if (fused) {
1615     // If init_value was fused into this reduce we have to generate it first.
1616     GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
1617                                             ir_emitter_context_->llvm_module(),
1618                                             &b_, GetNestedComputer());
1619 
1620     FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo),
1621                                  &elemental_emitter);
1622     TF_RETURN_IF_ERROR(init_value_operand->Accept(&fused_emitter));
1623     TF_RETURN_IF_ERROR(
1624         ParallelLoopEmitter(fused_emitter.GetGenerator(init_value_operand),
1625                             GetIrArray(*hlo, *hlo, index), launch_dimensions,
1626                             &b_)
1627             .EmitLoop(IrName(hlo)));
1628   } else {
1629     // In the unfused case the element is already there, just read from it.
1630     TF_RETURN_IF_ERROR(ParallelLoopEmitter(
1631                            [=](const IrArray::Index& index) {
1632                              return GetIrArray(*init_value, *hlo)
1633                                  .EmitReadArrayElement(index, &b_);
1634                            },
1635                            GetIrArray(*hlo, *hlo, index), launch_dimensions,
1636                            &b_)
1637                            .EmitLoop(IrName(hlo)));
1638   }
1639 
1640   // Clean up state left behind by emitting the loop above.  (This is normally
1641   // done in IrEmitterUnnested::Postprocess().)
1642   bindings_.UnbindAllLocalIrValues();
1643 
1644   // Convert unique_ptr<KernelThunk> to StatusOr<unique_ptr<Thunk>>.
1645   return {std::move(kernel_thunk)};
1646 }
1647 
1648 namespace {
1649 
1650 // Checks that the buffers corresponding to the given two HLOs share the same
1651 // allocation.
1652 Status CheckHloBuffersShareAllocation(
1653     const HloInstruction* a, const HloInstruction* b, const ShapeIndex& index,
1654     const BufferAssignment& buffer_assignment) {
1655   const BufferAllocation::Slice slice_a =
1656       buffer_assignment.GetUniqueSlice(a, index).ConsumeValueOrDie();
1657   const BufferAllocation::Slice slice_b =
1658       buffer_assignment.GetUniqueSlice(b, index).ConsumeValueOrDie();
1659   if (slice_a != slice_b) {
1660     return InternalError(
1661         "instruction %s %s does not share allocation with instruction %s %s",
1662         a->ToString(), slice_a.ToString(), b->ToString(), slice_b.ToString());
1663   }
1664   return Status::OK();
1665 }
1666 
1667 // Checks that all buffers used during while loop iteration share the same
1668 // buffer allocation. This includes buffers for while result, while init
1669 // operand, condition parameter, body parameter and body result.
1670 // Returns OK on success, error status otherwise.
1671 Status CheckWhileBuffersShareAllocation(
1672     const HloInstruction* xla_while,
1673     const BufferAssignment& buffer_assignment) {
1674   return ShapeUtil::ForEachSubshapeWithStatus(
1675       xla_while->shape(),
1676       [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
1677         const HloInstruction* condition_parameter =
1678             xla_while->while_condition()->parameter_instruction(0);
1679         const HloComputation* body = xla_while->while_body();
1680         const HloInstruction* body_parameter = body->parameter_instruction(0);
1681         const HloInstruction* body_result = body->root_instruction();
1682         TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
1683             xla_while, xla_while->operand(0), index, buffer_assignment));
1684         TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
1685             xla_while, condition_parameter, index, buffer_assignment));
1686         TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
1687             xla_while, body_parameter, index, buffer_assignment));
1688         TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
1689             xla_while, body_result, index, buffer_assignment));
1690         return Status::OK();
1691       });
1692 }
1693 
1694 // Checks that the buffers used in a conditional instruction are shared with the
1695 // operands and result as follows:
1696 //   * The result buffer of the conditional should share the allocation with the
1697 //     result buffers of each branch computation.
1698 //   * The buffer of operand b+1 should share the allocation with the buffer of
1699 //     the parameter 0 instruction of the b'th computation.
1700 Status CheckConditionalBuffersShareAllocation(
1701     const HloInstruction* conditional,
1702     const BufferAssignment& buffer_assignment) {
1703   TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
1704       conditional->shape(),
1705       [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
1706         for (auto branch_computation : conditional->branch_computations()) {
1707           TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
1708               conditional, branch_computation->root_instruction(), index,
1709               buffer_assignment));
1710         }
1711         return Status::OK();
1712       }));
1713   for (int j = 0; j < conditional->branch_count(); ++j) {
1714     TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
1715         conditional->operand(j + 1)->shape(),
1716         [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
1717           return CheckHloBuffersShareAllocation(
1718               conditional->operand(j + 1),
1719               conditional->branch_computation(j)->parameter_instruction(0),
1720               index, buffer_assignment);
1721         }));
1722   }
1723   return Status::OK();
1724 }
1725 
1726 }  // namespace
1727 
1728 std::unique_ptr<Thunk> IrEmitterUnnested::BuildWhileThunk(
1729     const HloInstruction* hlo) {
1730   // Check that all while-related buffers share an allocation.
1731   TF_CHECK_OK(CheckWhileBuffersShareAllocation(
1732       hlo, ir_emitter_context_->buffer_assignment()));
1733 
1734   // Generate thunk sequence for while 'condition'.
1735   HloComputation* condition = hlo->while_condition();
1736   IrEmitterUnnested ir_emitter_condition(hlo_module_config_, condition,
1737                                          ir_emitter_context_);
1738   TF_CHECK_OK(condition->Accept(&ir_emitter_condition));
1739 
1740   // Generate thunk sequence for while 'body'.
1741   HloComputation* body = hlo->while_body();
1742   IrEmitterUnnested ir_emitter_body(hlo_module_config_, body,
1743                                     ir_emitter_context_);
1744   TF_CHECK_OK(body->Accept(&ir_emitter_body));
1745 
1746   return absl::make_unique<WhileThunk>(
1747       GetAllocationSlice(*condition->root_instruction()),  // cond result
1748       ir_emitter_condition.ConsumeThunkSequence(),
1749       ir_emitter_body.ConsumeThunkSequence(), hlo);
1750 }
1751 
1752 std::unique_ptr<Thunk> IrEmitterUnnested::BuildForThunk(
1753     const HloInstruction* hlo, const int64 loop_limit) {
1754   // Check that all while-related buffers share an allocation.
1755   TF_CHECK_OK(CheckWhileBuffersShareAllocation(
1756       hlo, ir_emitter_context_->buffer_assignment()));
1757 
1758   // Generate thunk sequence for while 'body' (will be used as a For loop body).
1759   HloComputation* body = hlo->while_body();
1760   IrEmitterUnnested ir_emitter_body(hlo_module_config_, body,
1761                                     ir_emitter_context_);
1762   TF_CHECK_OK(body->Accept(&ir_emitter_body));
1763 
1764   return absl::make_unique<ForThunk>(
1765       loop_limit, ir_emitter_body.ConsumeThunkSequence(), hlo);
1766 }
1767 
1768 std::unique_ptr<Thunk> IrEmitterUnnested::BuildConditionalThunk(
1769     const HloInstruction* hlo) {
1770   // Check that the buffers used in the conditional are shared with the
1771   // operands and result appropriately.
1772   TF_CHECK_OK(CheckConditionalBuffersShareAllocation(
1773       hlo, ir_emitter_context_->buffer_assignment()));
1774 
1775   std::vector<BufferAllocation::Slice> branch_operands;
1776   std::vector<ThunkSequence> branch_thunks;
1777   for (int j = 0; j < hlo->branch_count(); ++j) {
1778     branch_operands.emplace_back(GetAllocationSlice(*hlo->operand(j + 1)));
1779     HloComputation* branch_computation = hlo->branch_computation(j);
1780     IrEmitterUnnested ir_emitter(hlo_module_config_, branch_computation,
1781                                  ir_emitter_context_);
1782     TF_CHECK_OK(branch_computation->Accept(&ir_emitter));
1783     branch_thunks.push_back(std::move(*ir_emitter.ConsumeThunkSequence()));
1784   }
1785 
1786   return absl::make_unique<ConditionalThunk>(
1787       GetAllocationSlice(*hlo->operand(0)), branch_operands,
1788       std::move(branch_thunks), hlo);
1789 }
1790 
1791 Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
1792     const HloInstruction& hlo,
1793     const llvm_ir::ElementGenerator& element_generator, KernelThunk* thunk) {
1794   int unroll_factor = thunk->unroll_factor();
1795   VLOG(3) << bindings_.ToString();
1796 
1797   bool multi_output = hlo.shape().IsTuple();
1798 
1799   const Shape& element_shape =
1800       multi_output ? ShapeUtil::GetSubshape(hlo.shape(), {0}) : hlo.shape();
1801   VLOG(3) << "EmitTargetElementLoopInThunk "
1802           << ShapeUtil::HumanStringWithLayout(hlo.shape())
1803           << " for unroll_factor " << unroll_factor;
1804   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
1805       element_shape, ir_emitter_context_->device_description(), unroll_factor);
1806   UpdateLaunchDimensions(launch_dimensions, thunk,
1807                          ir_emitter_context_->llvm_module());
1808   if (!multi_output) {
1809     return ParallelLoopEmitter(element_generator, GetIrArray(hlo, hlo),
1810                                launch_dimensions, &b_, unroll_factor)
1811         .EmitLoop(
1812             IrName(&hlo),
1813             GetIndexTypeForKernel(&hlo, launch_dimensions.launch_bound(), &b_));
1814   }
1815 
1816   // Emit the tuple pointers in one thread.  We could do this at any point in
1817   // the kernel, but we do it at the beginning in the hopes of reducing register
1818   // pressure, since we touch threadIdx.x and blockIdx.x at the beginning of the
1819   // kernel *anyway*.
1820   std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(hlo);
1821   KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] {
1822     llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_);
1823   });
1824 
1825   // For multioutput fusion, we need to emit each operand and the root.
1826   TF_RETURN_IF_ERROR(
1827       ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions,
1828                           &b_, unroll_factor)
1829           .EmitLoop(IrName(&hlo),
1830                     GetIndexTypeForKernel(
1831                         &hlo, launch_dimensions.launch_bound(), &b_)));
1832 
1833   b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
1834   return Status::OK();
1835 }
1836 
1837 namespace {
1838 
1839 // Returns true if the fusion contains any instruction that is likely to be
1840 // translated to complex LLVM IR, such as loops, and would prevent vectorization.
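// Editorial example: a fusion whose computation contains a kSin or a kReduce
// is reported as preventing vectorization (so no unrolling is attempted),
// while an unfused elementwise kAdd returns false and may be unrolled by
// EmitTargetElementLoop below.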
1841 bool MayPreventVectorization(const HloInstruction& hlo) {
1842   if (hlo.opcode() == HloOpcode::kFusion) {
1843     return absl::c_any_of(hlo.fused_instructions_computation()->instructions(),
1844                           [](const HloInstruction* instr) {
1845                             switch (instr->opcode()) {
1846                               case HloOpcode::kReduce:
1847                               case HloOpcode::kReduceWindow:
1848                               case HloOpcode::kSort:
1849                               case HloOpcode::kDot:
1850                               case HloOpcode::kSin:
1851                               case HloOpcode::kCos:
1852                               case HloOpcode::kPower:
1853                               case HloOpcode::kAtan2:
1854                                 return true;
1855                               default:
1856                                 return false;
1857                             }
1858                           });
1859   } else if (hlo.IsElementwise()) {
1860     // Unfused elementwise operations are usually memory bound; unroll them.
1861     switch (hlo.opcode()) {
1862         // The following elementwise operation implementations contain branches.
1863         // The LLVM vectorizer doesn't work in that case.
1864         // The unrolled code is faster when it isn't vectorized.
1865       case HloOpcode::kSin:
1866       case HloOpcode::kCos:
1867       case HloOpcode::kPower:
1868       case HloOpcode::kAtan2:
1869         return true;
1870       default:
1871         return false;
1872     }
1873   }
1874   return true;
1875 }
1876 
1877 }  // namespace
1878 
1879 Status IrEmitterUnnested::EmitTargetElementLoop(
1880     const HloInstruction& hlo,
1881     const llvm_ir::ElementGenerator& element_generator) {
1882   int unroll_factor = 1;
1883   if (!MayPreventVectorization(hlo)) {
1884     unroll_factor = ComputeMaxUnrollFactor(&hlo);
1885   }
1886 
1887   std::unique_ptr<KernelThunk> kernel_thunk = BuildKernelThunk(
1888       &hlo, /*implements_whole_instruction=*/true, unroll_factor);
1889   Status emit_status =
1890       EmitTargetElementLoopInThunk(hlo, element_generator, kernel_thunk.get());
1891   thunk_sequence_->emplace_back(std::move(kernel_thunk));
1892 
1893   return emit_status;
1894 }
1895 
1896 // Gets the output offset as calculated from thread_id.x (to be applied to the
1897 // offset calculated from block_id and thread_id.y).
1898 static llvm::Value* GetStartOffsetX(const KernelMappingScheme& mapping_scheme,
1899                                     llvm::Value* thread_id_x,
1900                                     llvm::Type* index_ty,
1901                                     llvm::IRBuilder<>* b) {
1902   if (mapping_scheme.DilatedX()) {
1903     return thread_id_x;
1904   }
1905   int64 x_num_steps =
1906       mapping_scheme.GetTileSizeX() / mapping_scheme.GetNumThreadsX();
1907   return b->CreateMul(thread_id_x,
1908                       llvm::ConstantInt::get(index_ty, x_num_steps));
1909 }
1910 
1911 // Emits code to process up to
1912 // (tile_size_x/num_threads_x * tile_size_y/num_threads_y) elements in a tile,
1913 // given `emit_elem_function` is the function to emit code to process one
1914 // element, `y` and `x` are the intra-tile coordinates for the first element
1915 // to process, and `index` is the index for the origin of the tile. Information
1916 // about tile_size_x/y and num_threads_x/y are stored in `mapping_scheme`. Emits
1917 // bounds check to ensure that each processed element is within the boundary
1918 // defined by `tile_width` and `tile_height`.
1919 //
1920 // Pseudocode:
1921 //
1922 // for (y_loc = 0; y_loc < tile_height; y_loc += num_threads_y) {
1923 //   for (j = 0; j < tile_size_x / num_threads_x; j++) { // unrolled
1924 //     if (dilated) {
1925 //       x_loc = x + j * num_threads_x;
1926 //     } else {
1927 //       x_loc = x * (tile_size_x / num_threads_x) + j;
1928 //     }
1929 //
1930 //     if (x_loc < tile_width) {
1931 //       emit_elem_function(y + y_loc, x_loc);
1932 //     }
1933 //   }
1934 // }
1935 //
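// Editorial example: with tile_size_x = 64 and num_threads_x = 32 we get
// x_num_steps = 2, i.e. two unrolled iterations per thread. In the dilated
// scheme a thread with x-coordinate x touches columns x and x + 32 (stride
// num_threads_x); in the non-dilated scheme it touches the contiguous pair
// 2 * x and 2 * x + 1.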
1936 static void EmitTile(
1937     const KernelMappingScheme& mapping_scheme,
1938     const IrArray::Index& tile_origin_index, const string& loop_name,
1939     KernelSupportLibrary* ksl, llvm::IRBuilder<>* b, llvm::Value* y,
1940     llvm::Value* x, llvm::Value* tile_height, llvm::Value* tile_width,
1941     const IrEmitterUnnested::EmitElementFunction& emit_elem_function) {
1942   llvm::Type* index_ty = tile_width->getType();
1943   auto constant = [&](int64 val) {
1944     return llvm::ConstantInt::get(index_ty, val);
1945   };
1946   int64 num_threads_x = mapping_scheme.GetNumThreadsX();
1947   int64 num_threads_y = mapping_scheme.GetNumThreadsY();
1948   int64 tile_size_x = mapping_scheme.GetTileSizeX();
1949 
1950   int64 x_num_steps = tile_size_x / num_threads_x;
1951   llvm::Value* start_offset_x = GetStartOffsetX(mapping_scheme, x, index_ty, b);
1952 
1953   // With a dilated mapping scheme, each thread steps with a stride equal to
1954   // the number of threads.
1955   // Otherwise, the stride is one, and each thread's offset is multiplied by
1956   // the limit on the number of steps it can take.
1957   int64 step_x = mapping_scheme.DilatedX() ? num_threads_x : 1;
1958 
1959   IrArray::Index source_idx =
1960       tile_origin_index.AddOffsetToDim(start_offset_x, kDimX, b);
1961 
1962   ksl->For(
1963       loop_name + "_y_in_tile",
1964       /*start=*/y,
1965       /*end=*/tile_height,
1966       /*step=*/constant(num_threads_y), [&](llvm::Value* y_loc) {
1967         IrArray::Index source_idx_y =
1968             source_idx.AddOffsetToDim(y_loc, kDimY, b);
1969         for (int64 j = 0; j < x_num_steps; j++) {
1970           llvm::Value* x_loc =
1971               b->CreateAdd(constant(j * step_x), start_offset_x, "x_loc");
1972           IrArray::Index source_idx_x =
1973               source_idx_y.AddOffsetToDim(constant(j * step_x), kDimX, b);
1974           // The if-statement below always evaluates to true for the blocks
1975           // where the entire processed tile fits within the input buffer.
1976           ksl->If(loop_name + "_x_in_tile", b->CreateICmpULT(x_loc, tile_width),
1977                   [&] { emit_elem_function(source_idx_x, y_loc, x_loc, j); });
1978         }
1979       });
1980 }
1981 
1982 // Emits code to process a tensor element in a tile for the given kCopy HLO that
1983 // performs a 0-2-1 transpose.
1984 //
1985 // index: The index for the first output element in the normalized tensor. The
1986 //   normalized tensor is the resulting tensor after collapsing contiguous
1987 //   dimensions that play the same role in the transpose.
1988 // y_loc: The y coordinate within a tile.
1989 // x_loc: The x coordinate within a tile.
1990 // mapping_scheme: Kernel mapping scheme specifying the tiling
1991 void IrEmitterUnnested::EmitTileElementForCopy(
1992     HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
1993     const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
1994     llvm::Value* x_loc, int64 /*x_iter_num*/,
1995     absl::Span<llvm::Value* const> param_shmem_buffers) {
1996   // TODO(jlebar): Add AA metadata to this load.
1997   llvm::Instruction* load_from_shmem_buffer =
1998       Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x_loc, y_loc}),
1999            "output_element");
2000   llvm_ir::IrArray output_array = GetIrArray(*hlo, *hlo);
2001   Shape output_reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
2002       hlo->shape().element_type(), mapping_scheme.GetDimsInElems());
2003   // When the output_reduced_shape is a 0-2-1 transpose of the input shape,
2004   // the 0-2-1 transpose is achieved through EmitWriteArrayElement.
2005   output_array.CastToShape(output_reduced_shape, &b_)
2006       .EmitWriteArrayElement(index, load_from_shmem_buffer, &b_);
2007 }
2008 
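// Maps an index into the normalized (collapsed to [Z, Y, X]) shape back to an
// index into the original, unnormalized shape by linearizing it against the
// kernel mapping scheme's dimensions and delinearizing it against
// `unnormalized_shape`. Editorial sketch: a normalized index (z, y, x) with
// dims [Z, Y, X] linearizes to (z * Y + y) * X + x before being re-expanded.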
2009 static IrArray::Index GetUnnormalizedIndex(
2010     const IrArray::Index& normalized_shape_index,
2011     const Shape& unnormalized_shape, llvm::IRBuilder<>* b_,
2012     const KernelMappingScheme& kernel_mapping_scheme) {
2013   DCHECK_EQ(normalized_shape_index.size(), 3);
2014   llvm::Value* linear = normalized_shape_index.Linearize(
2015       kernel_mapping_scheme.GetDimsInElems(), b_);
2016   return IrArray::Index(linear, unnormalized_shape, b_);
2017 }
2018 
2019 // Emits code to process a tensor element in a tile for the given kLoop fusion
2020 // HLO containing parameters that are 0-2-1 transpose of its outputs.
2021 //
2022 // index: The index for the first output element in the normalized tensor, that
2023 //   is the resulting tensor after collapsing contiguous dimensions that play
2024 //   the same role in the transpose.
2025 // kernel_info: Other information to support the kernel code generation.
2026 // y_loc: The y coordinate within a tile.
2027 // x_loc: The x coordinate within a tile.
2028 void IrEmitterUnnested::EmitTileElementForFusion(
2029     HloInstruction* hlo, const llvm_ir::IrArray::Index& index,
2030     const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
2031     llvm::Value* x_loc, int64 /*x_iter_num*/,
2032     absl::Span<llvm::Value* const> param_shmem_buffers) {
2033   std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo);
2034   GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
2035                                      GetNestedComputer());
2036   FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(hlo),
2037                                &elem_emitter, x_loc, y_loc,
2038                                param_shmem_buffers);
2039 
2040   TF_CHECK_OK(hlo->fused_expression_root()->Accept(&fused_emitter));
2041   IrArray::Index untiled_index = GetUnnormalizedIndex(
2042       index, output_arrays[0].GetShape(), &b_, mapping_scheme);
2043   const llvm_ir::ElementGenerator& output_generator =
2044       fused_emitter.GetRootGenerator();
2045   llvm::Value* output_value = output_generator(untiled_index).ValueOrDie();
2046   if (hlo->IsMultiOutputFusion()) {
2047     DCHECK(output_value->getType()->isStructTy());
2048     DCHECK_EQ(output_value->getType()->getStructNumElements(),
2049               output_arrays.size());
2050     for (int64 i = 0; i < output_arrays.size(); ++i) {
2051       output_arrays[i].EmitWriteArrayElement(
2052           untiled_index, ExtractValue(output_value, i), &b_);
2053     }
2054   } else {
2055     output_arrays[0].EmitWriteArrayElement(untiled_index, output_value, &b_);
2056   }
2057 }
2058 
2059 // Gets the number of partial results accumulated by a single thread
2060 // performing a reduction.
2061 static int GetNumberOfPartialResults(
2062     const ReductionCodegenInfo& reduction_info) {
2063   const KernelMappingScheme& mapping_scheme =
2064       reduction_info.GetKernelMappingScheme();
2065   if (reduction_info.IsRowReduction()) {
2066     return 1;
2067   }
2068   int64 num_partial_results = mapping_scheme.DilatedX() ? 1 : 2;
2069   CHECK_EQ(num_partial_results,
2070            (mapping_scheme.GetTileSizeX() / mapping_scheme.GetNumThreadsX()));
2071   return num_partial_results;
2072 }
2073 
2074 void IrEmitterUnnested::EmitPrologueForReduction(
2075     HloInstruction* unnested_hlo, ReductionCodegenInfo* reduction_info,
2076     absl::Span<HloInstruction* const> reduce_instructions,
2077     llvm::Type* index_type) {
2078   VLOG(10) << "Emit prologue for reduction: " << unnested_hlo->ToString();
2079   GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
2080                                           ir_emitter_context_->llvm_module(),
2081                                           &b_, GetNestedComputer());
2082   const HloInstruction* first_reduce = nullptr;
2083   for (int i = 0; i < reduce_instructions.size(); i++) {
2084     HloInstruction* reduce_inst = reduce_instructions[i];
2085     VLOG(10) << "Emit prologue for reduction: " << reduce_inst->ToString();
2086     if (first_reduce == nullptr) {
2087       first_reduce = reduce_inst;
2088     } else {
2089       CHECK(first_reduce->dimensions() == reduce_inst->dimensions());
2090     }
2091 
2092     AddressVector* reduction_input_addresses =
2093         reduction_info->GetMutableReductionInputAddresses();
2094     llvm::Type* element_type =
2095         llvm_ir::PrimitiveTypeToIrType(reduce_inst->shape().element_type(),
2096                                        ir_emitter_context_->llvm_module());
2097     llvm::AllocaInst* reduction_input_address = Alloca(element_type);
2098     reduction_input_addresses->push_back(reduction_input_address);
2099 
2100     int num_partial_results = GetNumberOfPartialResults(*reduction_info);
2101     AddressVector* partial_result_addresses =
2102         reduction_info->GetMutablePartialResultAddresses();
2103     llvm::AllocaInst* partial_result_address =
2104         Alloca(element_type, /*ArraySize=*/b_.getInt32(num_partial_results),
2105                "partial_reduction_result." + llvm::Twine(i));
2106     partial_result_addresses->push_back(partial_result_address);
2107 
2108     // Initialize the partial result with the initial value of the reduction.
2109     llvm::Value* init_ir_value;
2110     const HloInstruction* init_value = reduce_inst->operand(1);
2111     if (unnested_hlo->opcode() == HloOpcode::kFusion) {
2112       FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo),
2113                                    &elemental_emitter);
2114 
2115       TF_CHECK_OK(init_value->Accept(&fused_emitter));
2116       init_ir_value =
2117           fused_emitter
2118               .GetGenerator(init_value)(IrArray::Index(b_.getInt32Ty()))
2119               .ValueOrDie();
2120     } else {
2121       init_ir_value =
2122           GetIrArray(*init_value, *unnested_hlo)
2123               .EmitReadArrayElement(IrArray::Index(b_.getInt32Ty()), &b_);
2124     }
2125 
2126     for (int i = 0; i < num_partial_results; ++i) {
2127       Store(init_ir_value,
2128             InBoundsGEP(partial_result_address, {b_.getInt32(i)}));
2129     }
2130   }
2131 }
2132 
2133 void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForAllReduces(
2134     absl::Span<HloComputation* const> reducers,
2135     absl::Span<llvm::AllocaInst* const> partial_result_addresses) {
2136   CHECK_EQ(reducers.size(), partial_result_addresses.size());
2137   for (int i = 0; i != reducers.size(); i++) {
2138     EmitFullWarpShuffleDownLoopForReduce(
2139         reducers[i], partial_result_addresses[i]->getType()->getElementType(),
2140         partial_result_addresses[i]);
2141   }
2142 }
2143 
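// Reduces the per-lane partial result of a warp into lane 0 by repeatedly
// combining each lane's value with the value shuffled down from the lane
// `distance` away, for distance = 16, 8, 4, 2, 1 (assuming a warp size of
// 32). Editorial sketch for a sum reducer: after the first step lane i holds
// v[i] + v[i + 16]; after all five steps lane 0 holds the sum over the whole
// warp.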
2144 void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForReduce(
2145     HloComputation* reducer, llvm::Type* element_type,
2146     llvm::Value* partial_result_address) {
2147   for (int distance = 16; distance >= 1; distance /= 2) {
2148     int bit_width = llvm_ir::GetSizeInBits(element_type);
2149     llvm::Value* result_from_other_lane =
2150         Alloca(element_type, nullptr, "result_from_other_lane");
2151     // Bitcast cannot be applied to aggregate types (even packed ones), so
2152     // we bitcast addresses of load/store to intN* of the same bit-width.
2153     llvm::Type* shuffled_value_type =
2154         element_type->isStructTy() ? b_.getIntNTy(bit_width) : element_type;
2155     auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) {
2156       return b_.CreatePointerBitCastOrAddrSpaceCast(
2157           ptr, shuffled_value_type->getPointerTo());
2158     };
2159     llvm::Value* partial_result =
2160         Load(convert_pointer_for_shuffle(partial_result_address),
2161              "partial_reduction_result");
2162     Store(EmitFullWarpShuffleDown(partial_result, b_.getInt32(distance), &b_),
2163           convert_pointer_for_shuffle(result_from_other_lane));
2164     TF_CHECK_OK(EmitCallToNestedComputation(
2165         *reducer, {partial_result_address, result_from_other_lane},
2166         partial_result_address));
2167   }
2168 }
2169 
2170 // Given the IrArray index of a reduction input, returns the linear address of
2171 // the reduction output as if the reduction were going to keep the input shape
2172 // with the dimensions being reduced moved.
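// Editorial sketch: for a row reduction the output element is addressed by
// the kDimY coordinate alone; for a column reduction over normalized input
// dims [Z, Y, X], dim Y is reduced away and the linear output address is
// z * X + x.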
2173 static llvm::Value* GetUntransposedOutputLinearAddress(
2174     llvm::IRBuilder<>* b, const llvm_ir::IrArray::Index& index,
2175     const ReductionCodegenInfo& reduction_info) {
2176   const KernelMappingScheme& kernel_mapping_scheme =
2177       reduction_info.GetKernelMappingScheme();
2178   if (reduction_info.IsRowReduction()) {
2179     return index[kDimY];
2180   }
2181   absl::Span<const int64> dims_in_elem = kernel_mapping_scheme.GetDimsInElems();
2182   llvm::Value* x_dim_size = index.GetConstantWithIndexType(dims_in_elem[kDimX]);
2183   llvm::Value* x_block_offset = b->CreateMul(index[kDimZ], x_dim_size);
2184   return b->CreateAdd(x_block_offset, index[kDimX]);
2185 }
2186 
2187 void IrEmitterUnnested::EmitEpilogueForReduction(
2188     llvm::Type* index_ty, HloInstruction* unnested_hlo,
2189     const ReductionCodegenInfo& reduction_info,
2190     absl::Span<const HloInstruction* const> reduce_instructions,
2191     absl::Span<const ShapeIndex> reduction_output_shape_indices,
2192     absl::Span<HloComputation* const> reducers,
2193     const TilingKernelInfo& tiling_kernel_info) {
2194   const KernelMappingScheme& mapping_scheme =
2195       reduction_info.GetKernelMappingScheme();
2196   auto constant = [&](uint64 c) -> llvm::Constant* {
2197     return llvm::ConstantInt::get(index_ty, c);
2198   };
2199 
2200   IrEmitterUnnested::ThreadIdInfo thread_id_info =
2201       EmitThreadIdInfo(mapping_scheme.GetThreadsPerBlock(), index_ty,
2202                        mapping_scheme.GetNumThreadsX());
2203   llvm::Value* start_offset_x = GetStartOffsetX(
2204       mapping_scheme, thread_id_info.thread_id_x, index_ty, &b_);
2205 
2206   IrArray::Index start_offset =
2207       tiling_kernel_info.tile_origin
2208           .AddOffsetToDim(thread_id_info.thread_id_y, kDimY, &b_)
2209           .AddOffsetToDim(start_offset_x, kDimX, &b_);
2210 
2211   int num_reduces = reducers.size();
2212   absl::Span<llvm::AllocaInst* const> partial_result_addresses =
2213       reduction_info.GetPartialResultAddresses();
2214   if (reduction_info.IsRowReduction()) {
2215     EmitFullWarpShuffleDownLoopForAllReduces(reducers,
2216                                              partial_result_addresses);
2217     llvm_ir::LlvmIfData if_lane_id_is_zero_data = llvm_ir::EmitIfThenElse(
2218         ICmpEQ(thread_id_info.lane_id, constant(0)), "lane_id_is_zero", &b_);
2219     llvm_ir::SetToFirstInsertPoint(if_lane_id_is_zero_data.true_block, &b_);
2220   } else {
2221     // Some threads in the block are completely outside of the bound of the
2222     // tensor, so they should not do anything at all.
2223     llvm::Value* has_output = b_.CreateAnd(
2224         b_.CreateICmpULT(start_offset_x,
2225                          tiling_kernel_info.output_tile_bounds[kDimX]),
2226         b_.CreateICmpULT(thread_id_info.thread_id_y,
2227                          tiling_kernel_info.output_tile_bounds[kDimY]));
2228     llvm_ir::LlvmIfData if_has_output =
2229         llvm_ir::EmitIfThenElse(has_output, "output_inbound", &b_);
2230     llvm_ir::SetToFirstInsertPoint(if_has_output.true_block, &b_);
2231   }
2232 
2233   int num_partial_results = GetNumberOfPartialResults(reduction_info);
2234 
2235   // Emit an atomic operation that accumulates the partial reduction to the
2236   // output element. For row reduction, this is only for lane 0 due to the
2237   // if-statement emitted above.
2238   for (int i = 0; i != num_reduces; ++i) {
2239     const HloInstruction* reduce_hlo = reduce_instructions[i];
2240     Shape reduction_kept_element_shape = ShapeUtil::FilterDimensions(
2241         [&](int64 dim) {
2242           return !absl::c_linear_search(reduce_hlo->dimensions(), dim);
2243         },
2244         reduce_hlo->operand(0)->shape());
2245     for (int j = 0; j < num_partial_results; ++j) {
2246       llvm::Value* untransposed_output_linear_address =
2247           GetUntransposedOutputLinearAddress(
2248               &b_, start_offset.AddOffsetToDim(constant(j), kDimX, &b_),
2249               reduction_info);
2250 
2251       // A reduction is allowed to transpose its output.  For example, suppose
2252       // we are reducing the second dimension of f32[10,20,30]{3,2,1}.  We are
2253       // allowed to produce as output either f32[10,30]{1,0} (no transpose) or
2254       // f32[10,30]{0,1} (transposing the two output dims).
2255       //
2256       // At this point in the function we have a "partial sum" of input elements
2257       // (stored in partial_result_addresses), and we need to accumulate it into
2258       // the correct output element.
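      // Concretely (illustrative): the linear address computed below is first
      // unflattened against the kept-element shape (f32[10,30]{1,0} in the
      // example) to obtain a multi-dimensional index, and that index is then
      // re-linearized against the actual output array, whose layout may be
      // {1,0} or {0,1}. That is what the element_index -> output_index
      // remapping below does.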
2259       auto output_array = GetIrArray(*unnested_hlo, *unnested_hlo,
2260                                      reduction_output_shape_indices[i]);
2261       IrArray::Index element_index(
2262           /*linear=*/untransposed_output_linear_address,
2263           reduction_kept_element_shape, &b_);
2264       IrArray::Index output_index(element_index.multidim(),
2265                                   output_array.GetShape(),
2266                                   element_index.GetType());
2267       llvm::Value* output_address = output_array.EmitArrayElementAddress(
2268           output_index, &b_, "output_element_address");
2269       TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
2270           *reducers[i], output_address,
2271           InBoundsGEP(partial_result_addresses[i], {constant(j)})));
2272     }
2273   }
2274 }
2275 
2276 llvm::Value* IrEmitterUnnested::EmitBlockId() {
2277   return gpu::EmitCallToTargetIntrinsic(gpu::TargetIntrinsicID::kBlockIdx, {},
2278                                         {}, &b_);
2279 }
2280 
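// Debugging helper; prints from device code, prefixing each line with the
// printing thread and block ids. A hypothetical use (illustrative only, not
// present in the original code) that prints a value from thread 0 of block 0:
//
//   EmitPrintfWithThreadId("acc = %f", {accumulator_value},
//                          /*thread_id_filter=*/0, /*block_id_filter=*/0);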
2281 void IrEmitterUnnested::EmitPrintfWithThreadId(
2282     absl::string_view fmt, absl::Span<llvm::Value* const> arguments,
2283     absl::optional<int64> thread_id_filter,
2284     absl::optional<int64> block_id_filter) {
2285   llvm::Value* thread_id = EmitThreadId(1024, b_.getInt32Ty());
2286   llvm::Value* block_id = EmitBlockId();
2287   std::vector<llvm::Value*> updated_arguments = {thread_id, block_id};
2288   updated_arguments.insert(updated_arguments.end(), arguments.begin(),
2289                            arguments.end());
2290   llvm::Value* constraint = b_.getTrue();
2291   if (thread_id_filter) {
2292     constraint = b_.CreateAnd(
2293         constraint, b_.CreateICmpEQ(thread_id, b_.getInt32(*thread_id_filter)));
2294   }
2295   if (block_id_filter) {
2296     constraint = b_.CreateAnd(
2297         constraint, b_.CreateICmpEQ(block_id, b_.getInt32(*block_id_filter)));
2298   }
2299   KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
2300   ksl.If(constraint, [&] {
2301     ::xla::gpu::EmitPrintf(absl::StrCat("[TID=%d,BID=%d] ", fmt, "\n"),
2302                            updated_arguments, &b_);
2303   });
2304 }
2305 
2306 void IrEmitterUnnested::EmitTileElementForReduction(
2307     HloInstruction* unnested_hlo, const Shape& reduction_operand_shape,
2308     absl::Span<HloInstruction* const> output_instructions,
2309     const llvm_ir::IrArray::Index& index,
2310     const ReductionCodegenInfo& reduction_info,
2311     absl::Span<HloComputation* const> reducers, int64 x_iter_num) {
2312   VLOG(10) << "Emit tile element for reduce " << unnested_hlo->ToString();
2313   bool returns_tuple = output_instructions.size() > 1;
2314   int partial_result_index = reduction_info.IsRowReduction() ? 0 : x_iter_num;
2315 
2316   InlinedVector<llvm_ir::ElementGenerator, 1> input_gens;
2317   std::vector<std::pair<llvm_ir::ElementGenerator, ShapeIndex>>
2318       extra_output_gens;
2319   GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
2320                                      GetNestedComputer());
2321   FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo),
2322                                &elem_emitter);
2323   // Construct the ElementGenerator for each reduction and extra output in
2324   // the group of output instructions.
2325   if (unnested_hlo->opcode() == HloOpcode::kFusion) {
2326     TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter));
2327 
2328     for (int i = 0, e = output_instructions.size(); i != e; ++i) {
2329       const HloInstruction* inst = output_instructions[i];
2330       ShapeIndex idx = returns_tuple ? ShapeIndex({i}) : ShapeIndex({});
2331       if (IsReductionFromOrToContiguousDimensions(*inst)) {
2332         input_gens.push_back(fused_emitter.GetGenerator(inst->operand(0)));
2333       } else {
2334         extra_output_gens.emplace_back(fused_emitter.GetGenerator(inst),
2335                                        std::move(idx));
2336       }
2337     }
2338   } else {
2339     input_gens.push_back([&](const IrArray::Index& index) {
2340       return GetIrArray(*unnested_hlo->operand(0), *unnested_hlo)
2341           .EmitReadArrayElement(index, &b_);
2342     });
2343   }
2344 
2345   IrArray::Index input_index =
2346       GetUnnormalizedIndex(index, reduction_operand_shape, &b_,
2347                            reduction_info.GetKernelMappingScheme());
2348   // Clear the linear index field of the IrArray::Index to enable the use of
2349   // GetElementPointer with array types. This enables the vectorization of
2350   // the computation for different partial results. Use this index if
2351   // 'num_partial_results > 1'.
2352   int num_partial_results = GetNumberOfPartialResults(reduction_info);
2353   auto index_without_linear = IrArray::Index(
2354       input_index.multidim(), reduction_operand_shape, input_index.GetType());
2355 
2356   // Emit code to generate the input and perform the reduction computation for
2357   // each reduction instruction.
2358   for (int i = 0; i != reducers.size(); ++i) {
2359     llvm::AllocaInst* input_address =
2360         reduction_info.GetReductionInputAddresses()[i];
2361     llvm::AllocaInst* partial_reduction_result_address =
2362         reduction_info.GetPartialResultAddresses()[i];
2363     llvm::Value* const input_ir_value =
2364         input_gens[i](num_partial_results > 1 ? index_without_linear
2365                                               : input_index)
2366             .ValueOrDie();
2367     Store(input_ir_value, input_address);
2368     llvm::Value* partial_result_address = InBoundsGEP(
2369         partial_reduction_result_address, {b_.getInt32(partial_result_index)});
2370     TF_CHECK_OK(EmitCallToNestedComputation(
2371         *reducers[i], {partial_result_address, input_address},
2372         partial_result_address));
2373   }
2374 
2375   // Emit code to generate the output for the non-reduction instructions in the
2376   // fusion, if any.
2377   TF_CHECK_OK(EmitExtraOutputsForReduce(
2378       unnested_hlo, input_index,
2379       /*use_linear_index=*/num_partial_results == 1, extra_output_gens));
2380 }
2381 
2382 llvm::Value* IrEmitterUnnested::EmitThreadId(int64 threads_per_block,
2383                                              llvm::Type* index_ty) {
2384   // Emit the linear thread id within the block, annotate it with its range
2385   // [0, threads_per_block), and cast it to index_ty.
2386   llvm::CallInst* thread_id_raw = gpu::EmitCallToTargetIntrinsic(
2387       gpu::TargetIntrinsicID::kThreadIdx, {}, {}, &b_);
2388   llvm_ir::AddRangeMetadata(0, threads_per_block, thread_id_raw);
2389   return b_.CreateIntCast(thread_id_raw, index_ty,
2390                           /*isSigned=*/true, "thread.id.x");
2391 }
2392 
2393 IrEmitterUnnested::ThreadIdInfo IrEmitterUnnested::EmitThreadIdInfo(
2394     int64 threads_per_block, llvm::Type* index_ty, int64 num_threads_x) {
2395   auto constant = [&](uint64 c) -> llvm::Constant* {
2396     return llvm::ConstantInt::get(index_ty, c);
2397   };
2398   llvm::Value* thread_id = EmitThreadId(threads_per_block, index_ty);
2399   llvm::Value* num_threads_x_v = constant(num_threads_x);
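  // Worked example (illustrative): with threads_per_block = 128 and
  // num_threads_x = 32, a linear thread_id of 70 decomposes into
  // thread_id_x = 70 % 32 = 6, thread_id_y = 70 / 32 = 2, and
  // lane_id = 70 % kWarpSize = 6 (kWarpSize being 32 on current NVIDIA GPUs).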
2400   return {
2401       /*thread_id=*/thread_id,
2402       /*thread_id_x=*/b_.CreateURem(thread_id, num_threads_x_v, "thread_id.x"),
2403       /*thread_id_y=*/b_.CreateUDiv(thread_id, num_threads_x_v, "thread_id.y"),
2404       /*lane_id=*/b_.CreateURem(thread_id, constant(kWarpSize), "lane_id")};
2405 }
2406 
2407 IrEmitterUnnested::TilingKernelInfo IrEmitterUnnested::EmitTilingKernel(
2408     const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty,
2409     const TileElementGenerator& tile_element_generator) {
2410   absl::Span<const int64> dims_in_elems = mapping_scheme.GetDimsInElems();
2411   std::vector<int64> dims_in_blocks = {
2412       CeilOfRatio(dims_in_elems[0], mapping_scheme.GetTileSizeZ()),
2413       CeilOfRatio(dims_in_elems[1], mapping_scheme.GetTileSizeY()),
2414       CeilOfRatio(dims_in_elems[2], mapping_scheme.GetTileSizeX())};
2415   auto constant = [&](uint64 c) -> llvm::Constant* {
2416     return llvm::ConstantInt::get(index_ty, c);
2417   };
2418 
2419   IrEmitterUnnested::ThreadIdInfo thread_id_info =
2420       EmitThreadIdInfo(mapping_scheme.GetThreadsPerBlock(), index_ty,
2421                        mapping_scheme.GetNumThreadsX());
2422 
2423   KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
2424 
2425   const IrArray::Index block_coords = [&] {
2426     llvm::Value* block_id = EmitBlockId();
2427     llvm_ir::AddRangeMetadata(0, mapping_scheme.GetNumberOfBlocks(),
2428                               llvm::cast<llvm::Instruction>(block_id));
2429     llvm::Value* linear_block_id =
2430         b_.CreateIntCast(block_id, index_ty, /*isSigned=*/true, "block.id.x");
2431     IrArray::Index starting_block(linear_block_id,
2432                                   ShapeUtil::MakeShapeWithDescendingLayout(
2433                                       PRED /*arbitrary*/, dims_in_blocks),
2434                                   &b_);
2435 
2436     std::vector<llvm::Value*> multidim = {
2437         b_.CreateMul(starting_block[0], constant(mapping_scheme.GetTileSizeZ()),
2438                      "block_origin.z"),
2439         starting_block[1], starting_block[2]};
2440     return IrArray::Index(multidim, dims_in_blocks, index_ty);
2441   }();
2442 
2443   std::array<llvm::Value*, 3> output_tile_bounds;
2444   for (int i = kDimY; i < kDimTot; ++i) {
2445     int64 tile_size_for_dim = mapping_scheme.GetTileSizeFor(i);
2446     // Only the last row or column may not have full size.
2447     llvm::Value* is_last =
2448         b_.CreateICmpEQ(block_coords[i], constant(dims_in_blocks[i] - 1));
2449     int64 partial_row =
2450         dims_in_elems[i] - (dims_in_blocks[i] - 1) * tile_size_for_dim;
2451     output_tile_bounds[i] =
2452         b_.CreateSelect(is_last, constant(partial_row),
2453                         constant(tile_size_for_dim), "tile_bound");
2454   }
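  // Worked example (illustrative): if dims_in_elems[kDimX] = 100 and the tile
  // size in X is 32, then dims_in_blocks[kDimX] = CeilOfRatio(100, 32) = 4;
  // the first three blocks get a tile bound of 32 and the last block gets
  // 100 - 3 * 32 = 4.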
2455 
2456   IrArray::Index tile_origin = [&] {
2457     std::vector<llvm::Value*> elem_multi_index = block_coords.multidim();
2458     llvm::Type* index_ty = block_coords.GetType();
2459     for (int i = kDimY; i < kDimTot; ++i) {
2460       elem_multi_index[i] = b_.CreateMul(
2461           block_coords[i],
2462           llvm::ConstantInt::get(index_ty, mapping_scheme.GetTileSizeFor(i)),
2463           "tile_origin." + std::to_string(i));
2464     }
2465     return IrArray::Index(elem_multi_index, mapping_scheme.GetDimsInElems(),
2466                           index_ty);
2467   }();
2468 
2469   auto emit_tile = [&](const IrArray::Index& tile) {
2470     tile_element_generator(thread_id_info.thread_id_y,
2471                            thread_id_info.thread_id_x, tile, "output",
2472                            output_tile_bounds[1], output_tile_bounds[2], &ksl);
2473   };
2474 
2475   if (mapping_scheme.GetTileSizeZ() == 1) {
2476     emit_tile(tile_origin);
2477   } else {
2478     llvm::Value* starting_tile_index_for_dim = tile_origin[kDimZ];
2479     llvm::Value* block_size_for_dim = constant(mapping_scheme.GetTileSizeZ());
2480     llvm::Value* block_id_for_dim =
2481         b_.CreateUDiv(starting_tile_index_for_dim, block_size_for_dim);
2482     llvm::Value* last_block_for_dim = constant(dims_in_blocks[kDimZ] - 1);
2483     llvm::Value* last_block_size_for_dim =
2484         constant(dims_in_elems[kDimZ] -
2485                  (dims_in_blocks[kDimZ] - 1) * mapping_scheme.GetTileSizeZ());
2486 
2487     llvm::Value* num_tiles_in_block =
2488         b_.CreateSelect(b_.CreateICmpEQ(last_block_for_dim, block_id_for_dim),
2489                         last_block_size_for_dim, block_size_for_dim);
2490     ksl.For("loop_z",
2491             /*start=*/constant(0),
2492             /*end=*/num_tiles_in_block,
2493             /*step=*/1, [&](llvm::Value* block_dim_induction_var) {
2494               IrArray::Index tile_index = tile_origin.AddOffsetToDim(
2495                   block_dim_induction_var, kDimZ, &b_);
2496               emit_tile(tile_index);
2497             });
2498   }
2499   return {output_tile_bounds, tile_origin};
2500 }
2501 
2502 llvm::CallInst* IrEmitterUnnested::EmitSyncThreads() {
2503   return EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_);
2504 }
2505 
2506 // Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose
2507 // algorithm to improve the memory access patterns for the input parameters
2508 // with a shape that is a 0-2-1 transpose of the output tensor shape. The caller
2509 // is responsible for making sure that it is safe to apply the shared memory
2510 // transpose on the input parameters.
2511 //
2512 //
2513 // For the purpose of tiling, the output tensors have a logical shape of three
2514 // components 0-2-1 while the relevant input parameters have a logical shape
2515 // of three components 0-1-2 in the order major to minor. The x- and y-
2516 // dimensions of the tensors are tiled in square tiles with an edge length
2517 // `kTileSize`. Each thread block of `kTileSize` x `kNumRows` threads
2518 // transposes one tile: each thread copies kTileSize/kNumRows elements from
2519 // the input to a shared memory tile, then the otherwise "regular HLO kernel"
2520 // reads from the shared memory instead of the original input.
2521 //
2522 // This is similar to the following CUDA algorithm in TensorFlow:
2523 // https://goo.gl/MStRV6.
2524 //
2525 // `kTileSize` should usually be the same as the warp size. We currently
2526 // choose 32 for `kTileSize` and 4 for `kNumRows`; the CUDA algorithm uses 8.
2527 //
2528 // TODO(b/33320379): Here each block transposes 1 tile. It may be more
2529 // efficient to launch fewer blocks so each transposes many tiles.
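//
// As a concrete illustration (assuming the current kTileSize = 32 and
// kNumRows = 4): each block of 32 x 4 = 128 threads handles one 32 x 32 tile,
// so thread (y, x) copies the tile elements in column x at rows
// y, y + 4, ..., y + 28, i.e. 32 / 4 = 8 elements per thread.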
2530 void IrEmitterUnnested::EmitHlo021Tile(
2531     HloInstruction* hlo, Thunk* kernel_thunk,
2532     absl::Span<const int64> reduced_output_dims,
2533     absl::Span<const int64> tiled_param_ids) {
2534   constexpr int kNumRows = 4;
2535   KernelMappingScheme mapping_scheme(reduced_output_dims,
2536                                      /*tile_sizes=*/{1, kWarpSize, kWarpSize},
2537                                      /*num_threads_y=*/kNumRows,
2538                                      /*num_threads_x=*/kWarpSize,
2539                                      /*is_dilated_x=*/false);
2540   LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
2541                                      mapping_scheme.GetThreadsPerBlock());
2542   llvm::Type* index_type =
2543       GetIndexTypeForKernel(hlo, launch_dimensions.launch_bound(), &b_);
2544   std::vector<IrArray> param_arrays;
2545 
2546   // For each tiled parameter, cast its input IrArray to the corresponding
2547   // reduced shape and keep the reduced shape live during IR emission.
2548   std::vector<IrArray> param_in_reduced_shape_arrays;
2549   std::vector<llvm::Value*> param_shmem_buffers(hlo->operand_count(), nullptr);
2550 
2551   auto get_shared_memory_buffer = [&](llvm::Type* elem_ty,
2552                                       absl::string_view buffer_name) {
2553     // For NVIDIA GPUs, the warp size is 32 threads and shared memory is
2554     // organized into 32 banks. We usually use the warp size, or a multiple of
2555     // the warp size, as the tile size. This may cause all elements in the
2556     // same column of a tile to use the same memory bank and therefore lead to
2557     // shared memory bank conflicts. Adding 1 to the minor dimension of the
2558     // shared memory buffer reduces such bank conflicts.
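    // Worked example (illustrative, assuming 4-byte elements and 32 banks):
    // with the padded row length of 33, element (row, col) sits at linear
    // offset row * 33 + col, so moving down a column advances the bank index
    // by 33 % 32 = 1 and successive rows of the same column hit different
    // banks.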
2559     llvm::Type* buffer_type = llvm::ArrayType::get(
2560         llvm::ArrayType::get(elem_ty, mapping_scheme.GetTileSizeX() + 1),
2561         mapping_scheme.GetTileSizeY());
2562     return llvm_ir::AllocateSharedMemoryTile(b_.GetInsertBlock()->getModule(),
2563                                              buffer_type, buffer_name);
2564   };
2565 
2566   for (int64 id = 0; id < hlo->operand_count(); id++) {
2567     const HloInstruction* param = hlo->operand(id);
2568     param_arrays.push_back(GetIrArray(*param, *hlo));
2569 
2570     if (absl::c_linear_search(tiled_param_ids, id)) {
2571       param_shmem_buffers[id] =
2572           get_shared_memory_buffer(llvm_ir::PrimitiveTypeToIrType(
2573                                        param->shape().element_type(), module_),
2574                                    IrName(hlo, StrCat("tile", id)));
2575       VLOG(3) << "Added shmem buffer for parameter " << id << ": "
2576               << llvm_ir::DumpToString(*param_shmem_buffers[id]);
2577       Shape reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
2578           param->shape().element_type(),
2579           Permute({0, 2, 1}, reduced_output_dims));
2580       param_in_reduced_shape_arrays.push_back(
2581           param_arrays[id].CastToShape(reduced_shape, &b_));
2582     } else {
2583       param_in_reduced_shape_arrays.push_back(IrArray());
2584     }
2585   }
2586 
2587   EmitElementFunction element_generator =
2588       [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
2589           llvm::Value* x_loc, int64 x_iter_num) {
2590         if (hlo->opcode() == HloOpcode::kCopy) {
2591           EmitTileElementForCopy(hlo, index, mapping_scheme, y_loc, x_loc,
2592                                  x_iter_num, param_shmem_buffers);
2593         } else {
2594           CHECK_EQ(hlo->opcode(), HloOpcode::kFusion);
2595           EmitTileElementForFusion(hlo, index, mapping_scheme, y_loc, x_loc,
2596                                    x_iter_num, param_shmem_buffers);
2597         }
2598       };
2599 
2600   TileElementGenerator tile_generator =
2601       [&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
2602           const string& loop_name, llvm::Value* tile_height,
2603           llvm::Value* tile_width, KernelSupportLibrary* ksl) {
2604         // If shared memory transpose is needed, wait for all threads to reach
2605         // this point, lest we copy a value from tile to output before the other
2606         // thread copies it from input to tile. This is `__syncthreads` in CUDA.
2607         if (!tiled_param_ids.empty()) {
2608           // Calculate the input tile origin from the output tile origin.
2609           const IrArray::Index input_tile_origin(
2610               Permute({0, 2, 1}, index.multidim()),
2611               Permute({0, 2, 1}, index.dims()), index.GetType());
2612 
2613           // Copy input parameter values to shared memory buffers:
2614           // tile[y, x] = input[index]
2615           // Note that tile_width and tile_height are flipped here because we
2616           // are reading a transposed tile.
2617           EmitTile(mapping_scheme, input_tile_origin, "input", ksl, &b_, y, x,
2618                    tile_width, tile_height,
2619                    [&](const IrArray::Index& index, llvm::Value* y_loc,
2620                        llvm::Value* x_loc, int64 /*x_iter_num*/) {
2621                      for (int64 id : tiled_param_ids) {
2622                        IrArray& input_in_logical_shape =
2623                            param_in_reduced_shape_arrays[id];
2624 
2625                        llvm::Value* shmem_buffer = param_shmem_buffers[id];
2626                        llvm::Value* zero =
2627                            llvm::ConstantInt::get(index_type, 0);
2628                        // TODO(jlebar): Add AA metadata to this store.  Tile
2629                        // buffers are global variables, so LLVM can't infer much
2630                        // about it.
2631                        Store(input_in_logical_shape.EmitReadArrayElement(
2632                                  index, &b_, "input_element"),
2633                              GEP(shmem_buffer, {zero, y_loc, x_loc}));
2634                      }
2635                    });
2636 
2637           // Wait for all threads to reach this point using `__syncthreads` in
2638           // CUDA.
2639           EmitSyncThreads();
2640         }
2641 
2642         EmitTile(mapping_scheme, index, loop_name, ksl, &b_, y, x, tile_height,
2643                  tile_width, element_generator);
2644         bool block_contains_multi_tiles = mapping_scheme.GetTileSizeZ() > 1;
2645 
2646         // If a tile block contains multiple tiles and shared memory buffers are
2647         // used, we need to wait for all threads to finish using the shared
2648         // memory buffer for the current tile before we move on to process the
2649         // next tile and overwrite the shared memory buffers.
2650         if (block_contains_multi_tiles && !tiled_param_ids.empty()) {
2651           EmitSyncThreads();
2652         }
2653       };
2654 
2655   // For multioutput fusion, one thread needs to output a tuple
2656   // with pointers to all the individual outputs.  We could do this
2657   // at any point in the kernel, but we do it at the beginning in
2658   // the hopes of reducing register pressure, since we touch
2659   // threadIdx.x and blockIdx.x at the beginning of the kernel
2660   // *anyway*.
2661   if (hlo->IsMultiOutputFusion()) {
2662     KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] {
2663       llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo),
2664                          ConstructIrArrayForOutputs(*hlo), &b_);
2665     });
2666   }
2667 
2668   EmitTilingKernel(mapping_scheme, index_type, tile_generator);
2669   UpdateLaunchDimensions(launch_dimensions, kernel_thunk,
2670                          ir_emitter_context_->llvm_module());
2671 }
2672 
2673 namespace {
2674 // A recursive function to inspect the users of a parameter to determine
2675 // whether it's safe for that parameter to participate in a shared-memory
2676 // transpose.
2677 //
2678 // Consider a fusion parameter P for which we might want to use a shmem
2679 // transpose.  If we do, we use a GPU thread block to preload a tile of P with
2680 // indices [z, y..y+31, x..x+31] to compute an output tile with the same indices
2681 // cooperatively, where z, y, x are the indices for the normalized input/output
2682 // tensor (see the document for FindTranspose021 for the definition of
2683 // normalized tensor for 0-2-1 transpose). This shmem transpose implementation
2684 // requires that the computation of the output tile only read elements within
2685 // the preload tile. If this is not true, we can't use a shmem transpose for P.
2686 //
2687 // If the computation of output element [z, y, x] only requires the element of
2688 // P with the same indices, the shmem transpose implementation can be applied
2689 // to P safely. This is a sufficient but not necessary condition. We check all
2690 // the transitive users of P to see if any of them may violate this
2691 // condition. If no such user is found, we conclude that P is safe for shmem
2692 // transpose.
2693 //
2694 // This is trivially true for elementwise operations and some "data-movement"
2695 // ops like kTuple. However, it's not true for operations that can change the
2696 // dimensions of the inputs (e.g. pad, slice) or for the bitcast operation.
2697 // For example:
2698 //
2699 // fused_computation {
2700 //   param_0 = f32[64,64]{1,0} parameter(0)
2701 //   ROOT bitcast = f32[64,64]{0,1} bitcast(param_0)
2702 // }
2703 // The output element at logical address [0, 63] depends on the input element
2704 // at logical address [63, 0], which would not be within the shared-memory
2705 // block.
2706 //
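// By contrast, a purely elementwise use chain is safe. For example
// (illustrative, not taken from the original source):
//
// fused_computation {
//   param_0 = f32[64,64]{1,0} parameter(0)
//   exp = f32[64,64]{1,0} exponential(param_0)
//   ROOT add = f32[64,64]{1,0} add(exp, exp)
// }
//
// Here every output element [y, x] depends only on param_0[y, x], so the
// preloaded tile of param_0 contains everything the output tile needs.
//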
2707 // TODO(bixia): In order to extend this to kInput fusion, that is, reduction
2708 // with transpose, we only need to end the use-chain checking at the input of
2709 // a reduce operation. In this case, the above description of "output" applies
2710 // to the result of such a use-chain, which provides the input to the reduce
2711 // operation.
2712 bool IsInstructionSafeForShmemTranspose(const HloInstruction* hlo) {
2713   if (hlo->IsElementwise()) {
2714     return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) {
2715       return IsInstructionSafeForShmemTranspose(user);
2716     });
2717   }
2718 
2719   switch (hlo->opcode()) {
2720     // Non-elementwise instructions that don't cause the shmem transpose
2721     // to be unsafe, including instructions that are not currently fused.
2722     case HloOpcode::kGetDimensionSize:
2723       // The result of the operation doesn't rely on the content of the
2724       // tensor. As such, there is no need to further inspect its users.
2725       return true;
2726     case HloOpcode::kGetTupleElement:
2727     case HloOpcode::kMap:
2728     case HloOpcode::kParameter:
2729     case HloOpcode::kTuple:
2730     case HloOpcode::kTupleSelect:
2731       return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) {
2732         return IsInstructionSafeForShmemTranspose(user);
2733       });
2734 
2735     default:
2736       return false;
2737   }
2738 }
2739 
2740 // Given a group of input parameters that are 0-2-1 transpose of the outputs of
2741 // a fusion kernel, returns the input parameters that are safe for the shared
2742 // memory transpose implementation.
2743 //
2744 // When a tile based shared memory transpose is used to implement an input with
2745 // 0-2-1 transpose, we preload a tile of the input elements
2746 // [z, y..y+31, x..x+31] to compute the output tile elements of the same
2747 // indices. Preloading the input tile this way is only safe when the
2748 // computation of the output tile elements does not need any input element
2749 // outside the preloaded tile. We inspect all the transitive users of the
2750 // input parameter up to the fusion root instruction to see if we can find
2751 // any instruction that can make preloading the input tile unsafe.
2752 std::vector<int64> FilterInputsForShmemTranspose(const HloInstruction* fusion,
2753                                                  std::vector<int64> input_ids) {
2754   std::vector<int64> filtered_input_ids;
2755   for (int64 i = 0; i < input_ids.size(); ++i) {
2756     const HloInstruction* input = fusion->fused_parameter(input_ids[i]);
2757     if (IsInstructionSafeForShmemTranspose(input)) {
2758       filtered_input_ids.push_back(input_ids[i]);
2759     } else {
2760       VLOG(10) << "Input not safe for shmem transpose " << input->ToString();
2761     }
2762   }
2763   return filtered_input_ids;
2764 }
2765 
2766 }  // namespace
2767 
2768 bool IrEmitterUnnested::CheckAndEmitHloWithTile021(HloInstruction* hlo) {
2769   HloOpcode opcode = hlo->opcode();
2770 
2771   CHECK(hlo->IsLoopFusion() || opcode == HloOpcode::kCopy);
2772 
2773   const Shape& output_shape = hlo->IsMultiOutputFusion()
2774                                   ? ShapeUtil::GetSubshape(hlo->shape(), {0})
2775                                   : hlo->shape();
2776 
2777   // If the output_shape is reduced to 021 shape, find all the parameters of
2778   // the HLO that are in the corresponding 012 shape.
2779   std::vector<int64> params_012;
2780   optional<std::vector<int64>> reduced_dims_021;
2781   for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
2782        ++operand_idx) {
2783     HloInstruction* operand = hlo->mutable_operand(operand_idx);
2784     auto find_transpose_result =
2785         ShapeUtil::FindTranspose021(operand->shape(), output_shape);
2786     if (!find_transpose_result.has_value()) {
2787       continue;
2788     }
2789     const std::vector<int64>& curr_reduced_dims_021 = *find_transpose_result;
2790     if (!reduced_dims_021.has_value()) {
2791       reduced_dims_021 = curr_reduced_dims_021;
2792     }
2793     if (!absl::c_equal(*reduced_dims_021, curr_reduced_dims_021)) {
2794       // There is more than one possible transpose. Instead of picking one
2795       // transpose, we simply give up here.
2796       return false;
2797     }
2798     params_012.push_back(operand_idx);
2799   }
2800 
2801   if (!reduced_dims_021.has_value()) {
2802     return false;
2803   }
2804 
2805   if ((*reduced_dims_021)[1] < kMinDimensionToTransposeTiled ||
2806       (*reduced_dims_021)[2] < kMinDimensionToTransposeTiled) {
2807     return false;
2808   }
2809 
2810   if (opcode == HloOpcode::kFusion) {
2811     params_012 = FilterInputsForShmemTranspose(hlo, params_012);
2812     if (params_012.empty()) {
2813       return false;
2814     }
2815   }
2816 
2817   // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the
2818   // elements are of size 4 bytes), and CUDA has an architectural limit of
2819   // 48kb shared memory per SM.  (This is increased to 96kb in Volta, but we
2820   // don't use this, in part because it eats into our L1 cache space.)
2821   //
2822   // For correctness we need to ensure that we don't make more than 48kb worth
2823   // of shmem tiles per block.  And for performance, we'd probably like to use
2824   // significantly less, so that we can fit more than one block at a time on a
2825   // gpu core.
2826   //
2827   // We say without benchmarks that we want at least 3 blocks per core,
2828   // corresponding to at most 3 shmem tiles per block if the elements are 32
2829   // bits wide.  We choose which params get the shmem transpose treatment
2830   // arbitrarily; it's not clear if there's a Right Choice.
2831   //
2832   // This is only sound if tiled transposes are the only place where we use
2833   // shared memory in fusions.  If in the future other fusible ops use shared
2834   // memory, we'll have to adjust this heuristic.
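  //
  // Worked example (illustrative): for f32 parameters each tile occupies
  // 32 * 33 * 4 = 4224 bytes. With kMinBlocksPerCore = 3 the per-block budget
  // is 48kb / 3 = 16kb, so at most three such tiles are kept
  // (3 * 4224 = 12672 bytes) and a fourth would be dropped by the loop below.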
2835   constexpr int kMinBlocksPerCore = 3;
2836   constexpr int64 kShmemPerCore = 48 * 1024;
2837   int64 shmem_used = 0;
2838   for (int64 i = 0; i < params_012.size(); ++i) {
2839     const HloInstruction* operand = hlo->operand(params_012[i]);
2840     shmem_used +=
2841         32 * 33 *
2842         ShapeUtil::ByteSizeOfPrimitiveType(operand->shape().element_type());
2843 
2844     if (kMinBlocksPerCore * shmem_used > kShmemPerCore) {
2845       // Erase this element and everything after it from params_012.
2846       params_012.resize(i);
2847       break;
2848     }
2849   }
2850 
2851   if (params_012.empty()) {
2852     return false;
2853   }
2854 
2855   VLOG(3) << "EmitHlo021Tile: emitting 0-2-1 tile for " << hlo->ToString();
2856   std::unique_ptr<KernelThunk> kernel_thunk =
2857       BuildKernelThunk(hlo, /*implements_whole_instruction=*/true);
2858   EmitHlo021Tile(hlo, kernel_thunk.get(), *reduced_dims_021, params_012);
2859   AddThunkToThunkSequence(std::move(kernel_thunk));
2860   return true;
2861 }
2862 
2863 namespace {
2864 
2865 // Returns true if all the transitive users of hlo, up to the users in
2866 // use_chain_endings, are elementwise operations.
2867 bool AreUsersElementwise(const HloInstruction* hlo,
2868                          const ConstHloInstructionSet& use_chain_endings) {
2869   return absl::c_all_of(hlo->users(), [&](const HloInstruction* user) {
2870     return use_chain_endings.count(user) ||
2871            (user->IsElementwise() &&
2872             AreUsersElementwise(user, use_chain_endings));
2873   });
2874 }
2875 
2876 // Returns the number of fusion inputs that have the same dimensions as the
2877 // given shape and are involved in only elementwise operations.
2878 int64 NumInputsInvolveInOnlyElementwiseOps(
2879     const HloInstruction* unnested_hlo, const Shape& op_shape,
2880     const ConstHloInstructionSet& use_chain_endings) {
2881   return absl::c_count_if(
2882       unnested_hlo->fused_parameters(), [&](const HloInstruction* parameter) {
2883         const Shape& parameter_shape = parameter->shape();
2884         return ShapeUtil::SameDimensions(op_shape, parameter_shape) &&
2885                AreUsersElementwise(parameter, use_chain_endings);
2886       });
2887 }
2888 
2889 // Returns the number of fusion inputs that have more elements than the given
2890 // shape.
2891 int64 NumInputsWithMoreElementsThan(const HloInstruction* unnested_hlo,
2892                                     const Shape& shape) {
2893   int64 num_elements = ShapeUtil::ElementsIn(shape);
2894   return absl::c_count_if(
2895       unnested_hlo->fused_parameters(), [&](const HloInstruction* parameter) {
2896         return ShapeUtil::ElementsIn(parameter->shape()) > num_elements;
2897       });
2898 }
2899 
2900 // The benefit of unrolling a kInput fusion that is a column reduction comes
2901 // from the vectorization of non-reduction fusion outputs and fusion inputs.
2902 // On the other hand, unrolling can also introduce factors that can cause
2903 // the kernel to run slower. This routine uses a simple heuristic to estimate
2904 // the benefit as well as the overhead of unrolling in order to decide whether
2905 // unrolling is beneficial for the given kInput fusion.
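//
// Worked example (illustrative): for a fusion whose root is
// tuple(reduce, multiply), with two additional same-shaped inputs used only
// elementwise and no larger inputs, the tally below is
// can_be_vectorized = 1 + 2 = 3 versus cannot_be_vectorized = 1 (the atomic
// add of the reduce output), so unrolling is considered beneficial.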
2906 bool IsUnrollingColumnReductionBeneficial(const HloInstruction* unnested_hlo,
2907                                           const Shape& input_shape,
2908                                           int64 num_kept_minor) {
2909   // TODO(b/122468062): Needs further investigation to see whether we can
2910   // remove the constraint on IsPowerOfTwo.
2911   if (!IsPowerOfTwo(static_cast<uint64>(num_kept_minor))) {
2912     return false;
2913   }
2914 
2915   if (IsReductionFromOrToContiguousDimensions(*unnested_hlo)) {
2916     return true;
2917   }
2918 
2919   CHECK_EQ(unnested_hlo->opcode(), HloOpcode::kFusion);
2920   int64 can_be_vectorized = 0;
2921   int64 cannot_be_vectorized = 0;
2922   const HloInstruction* fused_root = unnested_hlo->fused_expression_root();
2923   ConstHloInstructionSet use_chain_endings;
2924   if (IsReductionFromOrToContiguousDimensions(*fused_root)) {
2925     use_chain_endings.insert(fused_root);
2926     // Atomic.add of the reduction result can't be vectorized.
2927     cannot_be_vectorized++;
2928   } else {
2929     CHECK_EQ(fused_root->opcode(), HloOpcode::kTuple);
2930     for (const HloInstruction* instr : fused_root->operands()) {
2931       if (IsReductionFromOrToContiguousDimensions(*instr)) {
2932         // Atomic.add of the reduction result can't be vectorized.
2933         cannot_be_vectorized++;
2934       } else {
2935         // Write of the non-reduction result can be vectorized.
2936         can_be_vectorized++;
2937       }
2938       use_chain_endings.insert(instr);
2939     }
2940   }
2941   // Fusion inputs that have the same dimensions as the reduce input and are
2942   // only involved in elementwise operations can be vectorized.
2943   can_be_vectorized += NumInputsInvolveInOnlyElementwiseOps(
2944       unnested_hlo, input_shape, use_chain_endings);
2945   // Fusion inputs with more elements than the reduce op input must participate
2946   // in non-elementwise operations and we assume that they are not vectorizable
2947   // for the purpose of estimating the benefit of unrolling. If the kernel is
2948   // unrolled even with such an assumption, and the accesses to those inputs
2949   // turn out to be vectorizable, the compiler will still vectorize them.
2950   cannot_be_vectorized +=
2951       NumInputsWithMoreElementsThan(unnested_hlo, input_shape);
2952   return can_be_vectorized >= cannot_be_vectorized;
2953 }
2954 
2955 }  // namespace
2956 
2957 ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo(
2958     const HloInstruction* unnested_hlo, const HloInstruction* first_reduce) {
2959   const Shape& input_shape = first_reduce->operand(0)->shape();
2960   ReductionDimensions reduction_dimensions =
2961       GetReductionKindAndContiguousComponents(*first_reduce);
2962   VLOG(10) << "is_row_reduction " << reduction_dimensions.is_row_reduction
2963            << " " << reduction_dimensions.dimensions[0] << " "
2964            << reduction_dimensions.dimensions[1] << " "
2965            << reduction_dimensions.dimensions[2];
2966 
2967   std::array<int64, 3> reduction_tiling =
2968       GetReductionTiling(reduction_dimensions);
2969   bool dilated_x =
2970       reduction_dimensions.is_row_reduction ||
2971       !IsUnrollingColumnReductionBeneficial(unnested_hlo, input_shape,
2972                                             reduction_dimensions.dimensions[2]);
2973 
2974   if (!dilated_x && !reduction_dimensions.is_row_reduction) {
2975     // Vectorized loads: a single thread reduces two adjacent columns.
2976     reduction_tiling[2] *= 2;
2977   }
2978 
2979   int64 num_threads_y = 1;
2980   int64 num_threads_x = [&] {
2981     if (reduction_dimensions.is_row_reduction) {
2982       return kWarpSize;
2983     }
2984     return std::min(
2985         ThreadsPerBlockLimit(ir_emitter_context_->device_description()),
2986         CeilOfRatio(reduction_dimensions.dimensions[2], reduction_tiling[2]));
2987   }();
2988 
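  // Illustrative example (tiling values hypothetical): for a row reduction,
  // num_threads_x is kWarpSize = 32 and num_threads_y is 1, so a
  // reduction_tiling of {1, 1, 16} yields a tile of {1, 1, 16 * 32 = 512}
  // elements, with each thread of the warp striding over 16 elements of its
  // row.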
2989   KernelMappingScheme mapping_scheme(
2990       reduction_dimensions.dimensions,
2991       {reduction_tiling[0], reduction_tiling[1] * num_threads_y,
2992        reduction_tiling[2] * num_threads_x},
2993       num_threads_y, num_threads_x, dilated_x);
2994   return ReductionCodegenInfo(mapping_scheme,
2995                               reduction_dimensions.is_row_reduction);
2996 }
2997 
2998 Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
2999     HloInstruction* unnested_hlo,
3000     absl::Span<HloInstruction* const> output_instructions) {
3001   bool returns_tuple = output_instructions.size() > 1;
3002   VLOG(10) << "Emitting reduction to vector " << unnested_hlo->ToString();
3003 
3004   std::vector<HloInstruction*> reduce_instructions;
3005   InlinedVector<ShapeIndex, 1> reduction_output_shape_indices;
3006   InlinedVector<HloComputation*, 1> reducers;
3007 
3008   // Build an initializer thunk to initialize each reduction output.
3009   std::vector<std::unique_ptr<Thunk>> thunks;
3010   for (int i = 0; i < output_instructions.size(); ++i) {
3011     if (!IsReductionFromOrToContiguousDimensions(*output_instructions[i])) {
3012       continue;
3013     }
3014 
3015     HloInstruction* output_instruction = output_instructions[i];
3016     reduce_instructions.push_back(output_instruction);
3017     ShapeIndex idx = returns_tuple ? ShapeIndex({i}) : ShapeIndex({});
3018     reduction_output_shape_indices.push_back(idx);
3019     reducers.push_back(output_instruction->to_apply());
3020 
3021     TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
3022                         BuildInitializerThunk(unnested_hlo, idx));
3023     thunks.push_back(std::move(initializer_thunk));
3024   }
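  // At this point `thunks` holds one initializer thunk per reduction output.
  // The kernel thunk that computes all the outputs is appended below, and the
  // whole sequence is wrapped in a SequentialThunk at the end.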
3025 
3026   const HloInstruction* first_reduce = reduce_instructions.at(0);
3027   if (output_instructions.size() > 1) {
3028     if (!AreFusedReductionOutputsConsistent(output_instructions,
3029                                             first_reduce)) {
3030       return InternalError("Inconsistent reduction fusion outputs");
3031     }
3032   }
3033 
3034   // Build a kernel thunk to compute all the outputs.
3035   std::unique_ptr<KernelThunk> kernel_thunk =
3036       BuildKernelThunk(unnested_hlo, /*implements_whole_instruction=*/false);
3037 
3038   const Shape& input_shape = first_reduce->operand(0)->shape();
3039   // The layout of a reduction input is either set by LayoutAssignment for
3040   // unnested kReduce or by InstructionFusion for fused kReduce.
3041   CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion "
3042                                      "doesn't set the input layout of "
3043                                   << first_reduce->ToString();
3044 
3045   ReductionCodegenInfo reduction_info =
3046       ComputeReductionCodegenInfo(unnested_hlo, first_reduce);
3047   const KernelMappingScheme& mapping_scheme =
3048       reduction_info.GetKernelMappingScheme();
3049   LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
3050                                      mapping_scheme.GetThreadsPerBlock());
3051   llvm::Type* index_ty = GetIndexTypeForKernel(
3052       unnested_hlo, launch_dimensions.launch_bound(), &b_);
3053   EmitPrologueForReduction(unnested_hlo, &reduction_info, reduce_instructions,
3054                            index_ty);
3055   EmitElementFunction emit_reduction_tile =
3056       [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
3057           llvm::Value* x_loc, int64 x_iter_num) {
3058         EmitTileElementForReduction(unnested_hlo, input_shape,
3059                                     output_instructions, index, reduction_info,
3060                                     reducers, x_iter_num);
3061       };
3062 
3063   TilingKernelInfo tiling_kernel_info = EmitTilingKernel(
3064       mapping_scheme, index_ty,
3065       [&](llvm::Value* y, llvm::Value* x, const IrArray::Index& index,
3066           const string& loop_name, llvm::Value* tile_height,
3067           llvm::Value* tile_width, KernelSupportLibrary* ksl) {
3068         EmitTile(reduction_info.GetKernelMappingScheme(), index, loop_name, ksl,
3069                  &b_, y, x, tile_height, tile_width, emit_reduction_tile);
3070       });
3071   EmitEpilogueForReduction(index_ty, unnested_hlo, reduction_info,
3072                            reduce_instructions, reduction_output_shape_indices,
3073                            reducers, tiling_kernel_info);
3074 
3075   UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
3076                          ir_emitter_context_->llvm_module());
3077 
3078   thunks.push_back(std::move(kernel_thunk));
3079   auto sequential_thunk =
3080       absl::make_unique<SequentialThunk>(std::move(thunks), unnested_hlo);
3081   AddThunkToThunkSequence(std::move(sequential_thunk));
3082 
3083   return Status::OK();
3084 }
3085 
3086 Status IrEmitterUnnested::EmitConstantGlobals() {
3087   for (const BufferAllocation& allocation :
3088        ir_emitter_context_->buffer_assignment().Allocations()) {
3089     if (!allocation.is_constant()) {
3090       continue;
3091     }
3092 
3093     const Literal& literal = llvm_ir::LiteralForConstantAllocation(allocation);
3094     const bool should_emit_initializer = ShouldEmitLiteralInLlvmIr(literal);
3095     llvm::ArrayType* global_type =
3096         llvm::ArrayType::get(b_.getInt8Ty(), allocation.size());
3097     llvm::Constant* initializer =
3098         should_emit_initializer
3099             ? llvm_ir::ConvertLiteralToIrConstant(literal, module_)
3100             : llvm::ConstantAggregateZero::get(global_type);
3101     if (should_emit_initializer) {
3102       VLOG(3) << "Emitted initializer for constant with shape "
3103               << ShapeUtil::HumanString(literal.shape());
3104     }
3105 
3106     // These globals will be looked up by name by GpuExecutable so we need to
3107     // give them an external linkage.  Not all of their uses are visible in
3108     // the LLVM IR (e.g. TupleThunk), so we can't give them a linkage that
3109     // merely preserves their names (like available_externally); we also need
3110     // to ensure that they stick around even if they're "unused".
3111     //
3112     // We may have to be more clever here in the future if we notice that
3113     // we're keeping around too many globals because of their linkage.
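    //
    // The emitted global is roughly of the form (illustrative; the actual
    // name comes from ConstantBufferAllocationToGlobalName and the alignment
    // from kConstantBufferAlignBytes):
    //
    //   @<global_name> = constant [<size> x i8] <initializer>, align <N>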
3114     unsigned global_address_space = llvm_ir::GetGlobalMemoryAddressSpace(
3115         *ir_emitter_context_->llvm_module());
3116     llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
3117         global_type, /*isConstant=*/should_emit_initializer,
3118         llvm::GlobalValue::ExternalLinkage,
3119         /*Initializer=*/initializer,
3120         llvm_ir::ConstantBufferAllocationToGlobalName(allocation),
3121         /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
3122         /*AddressSpace=*/global_address_space,
3123         /*isExternallyInitialized=*/false);
3124     global_for_const->setAlignment(kConstantBufferAlignBytes);
3125     ir_emitter_context_->llvm_module()->getGlobalList().push_back(
3126         global_for_const);
3127   }
3128 
3129   return Status::OK();
3130 }
3131 
3132 // Emits code for slices based on the below structure. An if statement with
3133 // a guarding condition is generated for each ROOT slice.
3134 //
3135 // Pseudo code:
3136 //
3137 // Compute values of slice input operands
3138 //
3139 // Compute guarding_cond0
3140 // if (guarding_cond0) {
3141 //   Write to output of slice0
3142 // }
3143 //
3144 // Compute guarding_cond1
3145 // if (guarding_cond1) {
3146 //   Write to output of slice1
3147 // }
3148 //
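// For example (illustrative): for a root slice
//   slice0 = s32[2]{0} slice(s32[8]{0} p), slice={[3:5]}
// the guarding condition is (index >= 3 && index < 5) and the value is
// written to output position index - 3.
//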
3149 void IrEmitterUnnested::EmitElementForInputFusibleSlices(
3150     HloInstruction* unnested_hlo, const llvm_ir::IrArray::Index& index) {
3151   VLOG(10) << "Emitting slice input fusion for " << unnested_hlo->ToString();
3152 
3153   HloInstruction* slice_or_tuple = unnested_hlo->fused_expression_root();
3154   auto slice_instructions = [&]() -> absl::Span<HloInstruction* const> {
3155     if (slice_or_tuple->opcode() == HloOpcode::kSlice) {
3156       return absl::Span<HloInstruction* const>(&slice_or_tuple, 1);
3157     }
3158     CHECK_EQ(slice_or_tuple->opcode(), HloOpcode::kTuple);
3159     return slice_or_tuple->operands();
3160   }();
3161 
3162   // Emit input operand values of slices.
3163   std::vector<llvm::Value*> input_ir_values;
3164   GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
3165                                      GetNestedComputer());
3166   FusedIrEmitter fused_emitter(GetGeneratorForOperandIrArrays(unnested_hlo),
3167                                &elem_emitter);
3168   TF_CHECK_OK(unnested_hlo->fused_expression_root()->Accept(&fused_emitter));
3169   for (const HloInstruction* slice : slice_instructions) {
3170     auto input_generator = fused_emitter.GetGenerator(slice->operand(0));
3171     input_ir_values.push_back(input_generator(index).ValueOrDie());
3172   }
3173 
3174   // Emit code for each of the slice_instructions.
3175   KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
3176   for (int64 i = 0; i < slice_instructions.size(); ++i) {
3177     HloInstruction* slice = slice_instructions[i];
3178 
3179     // guarding_cond := index >= start && index < limit, for each dim.
3180     std::vector<llvm::Value*> index_within_ranges;
3181     for (size_t dim = 0; dim < slice->slice_starts().size(); ++dim) {
3182       CHECK_EQ(slice->slice_strides(dim), 1);
3183       auto larger_or_equal_than_start = b_.CreateICmpSGE(
3184           index.multidim()[dim],
3185           index.GetConstantWithIndexType(slice->slice_starts(dim)));
3186       llvm::Value* smaller_than_limit = b_.CreateICmpSLT(
3187           index.multidim()[dim],
3188           index.GetConstantWithIndexType(slice->slice_limits(dim)));
3189       llvm::Value* within_range =
3190           b_.CreateAnd(larger_or_equal_than_start, smaller_than_limit);
3191       index_within_ranges.push_back(within_range);
3192     }
3193     llvm::Value* guarding_cond = b_.CreateAnd(index_within_ranges);
3194 
3195     auto emit_slice_elem_func = [&] {
3196       const std::vector<llvm::Value*>& src_multidim = index.multidim();
3197       std::vector<llvm::Value*> dst_multidim(src_multidim.size());
3198       for (size_t dim = 0; dim < src_multidim.size(); ++dim) {
3199         dst_multidim[dim] =
3200             Sub(src_multidim[dim],
3201                 index.GetConstantWithIndexType(slice->slice_starts(dim)));
3202       }
3203       ShapeIndex shape_index = (slice_or_tuple->opcode() == HloOpcode::kSlice)
3204                                    ? ShapeIndex()
3205                                    : ShapeIndex({i});
3206       llvm_ir::IrArray src_ir_array =
3207           GetIrArray(*unnested_hlo, *unnested_hlo, shape_index);
3208       IrArray::Index slice_dst_index(dst_multidim, slice->shape(),
3209                                      index.GetType());
3210       llvm::Value* dst_addr = src_ir_array.EmitArrayElementAddress(
3211           slice_dst_index, &b_, "slice.dest");
3212       b_.CreateStore(input_ir_values[i], dst_addr);
3213     };
3214 
3215     ksl.If(StrCat("slice", i), guarding_cond, emit_slice_elem_func);
3216   }
3217 }
3218 
3219 Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices(
3220     HloInstruction* unnested_hlo) {
3221   constexpr int unroll_factor = 1;
3222   std::unique_ptr<KernelThunk> kernel_thunk = BuildKernelThunk(
3223       unnested_hlo, /*implements_whole_instruction=*/true, unroll_factor);
3224 
3225   TF_ASSIGN_OR_RETURN(Shape element_shape,
3226                       GetConsistentInputShapeForRootSlices(*unnested_hlo));
3227   LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
3228       element_shape, ir_emitter_context_->device_description(), unroll_factor);
3229   UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
3230                          ir_emitter_context_->llvm_module());
3231 
3232   Status emit_status =
3233       ParallelLoopEmitter(
3234           [&](const llvm_ir::IrArray::Index index) -> Status {
3235             EmitElementForInputFusibleSlices(unnested_hlo, index);
3236             return Status::OK();
3237           },
3238           element_shape, launch_dimensions, &b_)
3239           .EmitLoop(IrName(unnested_hlo),
3240                     GetIndexTypeForKernel(
3241                         unnested_hlo, launch_dimensions.launch_bound(), &b_));
3242 
3243   thunk_sequence_->emplace_back(std::move(kernel_thunk));
3244 
3245   return emit_status;
3246 }
3247 
3248 }  // namespace gpu
3249 }  // namespace xla
3250