1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
17 
18 #include <algorithm>
19 #include <cstring>
20 #include <iterator>
21 #include <memory>
22 #include <string>
23 #include <type_traits>
24 #include <vector>
25 
26 #include "absl/algorithm/container.h"
27 #include "absl/container/inlined_vector.h"
28 #include "absl/memory/memory.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/str_format.h"
31 #include "absl/types/optional.h"
32 #include "absl/types/span.h"
33 #include "llvm/ADT/APInt.h"
34 #include "llvm/ADT/SetVector.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/IR/BasicBlock.h"
37 #include "llvm/IR/Function.h"
38 #include "llvm/IR/IRBuilder.h"
39 #include "llvm/IR/Instructions.h"
40 #include "llvm/IR/LLVMContext.h"
41 #include "llvm/IR/Module.h"
42 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
43 #include "mlir/IR/Attributes.h"  // from @llvm-project
44 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
45 #include "mlir/IR/Builders.h"  // from @llvm-project
46 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
47 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
48 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
49 #include "mlir/IR/Verifier.h"  // from @llvm-project
50 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
51 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
52 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
53 #include "tensorflow/compiler/mlir/utils/name_utils.h"
54 #include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
55 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
56 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
57 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
58 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
59 #include "tensorflow/compiler/xla/layout_util.h"
60 #include "tensorflow/compiler/xla/literal.h"
61 #include "tensorflow/compiler/xla/permutation_util.h"
62 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
63 #include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
64 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
65 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
66 #include "tensorflow/compiler/xla/service/gpu/bef_thunk.h"
67 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
68 #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
69 #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
70 #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
71 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
72 #include "tensorflow/compiler/xla/service/gpu/custom_call_thunk.h"
73 #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
74 #include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
75 #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
76 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
77 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
78 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
79 #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
80 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
81 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
82 #include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
83 #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
84 #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
85 #include "tensorflow/compiler/xla/service/gpu/memset_thunk.h"
86 #include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
87 #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
88 #include "tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h"
89 #include "tensorflow/compiler/xla/service/gpu/nccl_collective_permute_thunk.h"
90 #include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h"
91 #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
92 #include "tensorflow/compiler/xla/service/gpu/replica_id_thunk.h"
93 #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
94 #include "tensorflow/compiler/xla/service/gpu/target_util.h"
95 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
96 #include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"
97 #include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
98 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
99 #include "tensorflow/compiler/xla/service/hlo_computation.h"
100 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
101 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
102 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
103 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
104 #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
105 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
106 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
107 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
108 #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"
109 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
110 #include "tensorflow/compiler/xla/service/name_uniquer.h"
111 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
112 #include "tensorflow/compiler/xla/service/shape_inference.h"
113 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
114 #include "tensorflow/compiler/xla/shape_util.h"
115 #include "tensorflow/compiler/xla/status_macros.h"
116 #include "tensorflow/compiler/xla/types.h"
117 #include "tensorflow/compiler/xla/union_find.h"
118 #include "tensorflow/compiler/xla/util.h"
119 #include "tensorflow/compiler/xla/window_util.h"
120 #include "tensorflow/compiler/xla/xla_data.pb.h"
121 #include "tensorflow/core/lib/core/bits.h"
122 #include "tensorflow/core/lib/core/status.h"
123 #include "tensorflow/core/platform/errors.h"
124 #include "tensorflow/core/platform/logging.h"
125 
126 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
127 #include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"
128 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
129 
130 namespace xla {
131 namespace gpu {
132 
133 namespace {
134 
135 using absl::InlinedVector;
136 using absl::nullopt;
137 using absl::optional;
138 using absl::StrCat;
139 using llvm_ir::IrArray;
140 using llvm_ir::IrName;
141 
142 const auto kDimX = KernelMappingScheme::DimX;
143 const auto kDimY = KernelMappingScheme::DimY;
144 const auto kDimZ = KernelMappingScheme::DimZ;
145 const auto kDimTot = KernelMappingScheme::DimTot;
146 
147 const auto kLinearIndexingX = KernelMappingScheme::LinearIndexingX;
148 const auto kStridedIndexingX = KernelMappingScheme::StridedIndexingX;
149 const auto kStridedLinearIndexingX =
150     KernelMappingScheme::StridedLinearIndexingX;
151 
152 // If a dimension is smaller than this, untiled transposition may be more
153 // efficient.
154 const int64_t kMinDimensionToTransposeTiled = 16;
155 
156 // Annotates the launch dimensions of the corresponding IR kernel in
157 // `llvm_module`.
158 void AnnotateThunkLaunchDimensions(const LaunchDimensions& launch_dims,
159                                    const std::string& kernel_name,
160                                    llvm::Module* llvm_module) {
161   // Add __launch_bounds__ to metadata. This limits registers per thread to
162   // avoid out-of-resources launching errors.
163   llvm::NamedMDNode* nvvm_annotations_node =
164       llvm_module->getOrInsertNamedMetadata("nvvm.annotations");
165   llvm::Function* ir_kernel = llvm_module->getFunction(kernel_name.c_str());
166   llvm::LLVMContext& llvm_context = llvm_module->getContext();
167   llvm::ConstantInt* threads_per_block_ir_value = llvm::ConstantInt::get(
168       llvm::IntegerType::get(llvm_context, /*NumBits=*/32),
169       launch_dims.thread_counts_per_block().x);
170   // Our launch bounds are exact, so we can specify them as reqntidx rather than
171   // maxntidx.
172   nvvm_annotations_node->addOperand(llvm::MDNode::get(
173       llvm_context,
174       {llvm::ConstantAsMetadata::get(ir_kernel),
175        llvm::MDString::get(llvm_context, "reqntidx"),
176        llvm::ConstantAsMetadata::get(threads_per_block_ir_value)}));
177 }
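
// For illustration (kernel name and block size are hypothetical): for a
// kernel "fusion_3" launched with 128 threads per block, the code above
// emits NVVM metadata roughly equivalent to
//
//   !nvvm.annotations = !{!0}
//   !0 = !{void (i8*, i8*)* @fusion_3, !"reqntidx", i32 128}
//
// which lets the PTX backend assume the exact block size when budgeting
// registers.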
178 
179 bool BinarySearchDenseElementsAttr(mlir::DenseIntElementsAttr elements,
180                                    int64_t v) {
181   mlir::APInt value(sizeof(int64) * 8, v, /*isSigned=*/true);
182   return std::binary_search(
183       elements.begin(), elements.end(), value,
184       [](const mlir::APInt& x, const mlir::APInt& y) { return x.slt(y); });
185 }
186 
187 bool MhloOpIsElementwise(mlir::Operation* op) {
188   CHECK(op->getDialect() ==
189         op->getContext()->getLoadedDialect<mlir::mhlo::MhloDialect>());
190   auto opcode = *MhloToHloOpcode(op);
191   if (HloInstruction::IsOpElementwise(opcode)) {
192     return true;
193   }
194   if (opcode == HloOpcode::kMap) {
195     int iota = 0;
196     for (const llvm::APInt& i :
197          mlir::cast<mlir::mhlo::MapOp>(op).dimensions()) {
198       if (i.getZExtValue() != iota) {
199         return false;
200       }
201       iota++;
202     }
203     return true;
204   }
205   // TODO(timshen): not sure about whether porting
206   // HloFusionInstruction::IsElementwiseImpl() is necessary. HandleFusion()
207   // doesn't use such information.
208   return false;
209 }
210 
211 bool IsSingleInstructionFusion(mlir::lmhlo::FusionOp fusion) {
212   int instruction_count = 0;
213   for (mlir::Operation& instr : fusion.region().front()) {
214     if (mlir::isa<mlir::lmhlo::TerminatorOp, mlir::mhlo::ReturnOp,
215                   mlir::memref::TensorLoadOp, mlir::memref::TensorStoreOp>(
216             &instr)) {
217       continue;
218     }
219     instruction_count++;
220   }
221   return instruction_count == 1;
222 }
223 
224 bool MayPreventVectorization(mlir::Operation* op) {
225   // An empirically chosen constant: unrolling concat with a large number of
226   // arguments causes excessive register spilling.
227   static constexpr int kMaxConcatArgumentsForUnrolling = 10;
228 
229   auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(op);
230   const bool is_single_instruction = IsSingleInstructionFusion(fusion);
231 
232   for (mlir::Operation& instr : fusion.region().front()) {
233     if (mlir::isa<mlir::lmhlo::TerminatorOp, mlir::mhlo::ReturnOp,
234                   mlir::memref::TensorLoadOp, mlir::memref::TensorStoreOp>(
235             &instr)) {
236       continue;
237     }
238     if (is_single_instruction) {
239       auto instr_opcode = *MhloToHloOpcode(&instr);
240       if (MhloOpIsElementwise(&instr)) {
241         switch (instr_opcode) {
242           case HloOpcode::kSin:
243           case HloOpcode::kCos:
244           case HloOpcode::kPower:
245           case HloOpcode::kAtan2:
246             return true;
247           default:
248             return false;
249         }
250       } else if (instr_opcode == HloOpcode::kReduce &&
251                  instr.getNumResults() == 1) {
252         // TODO(timshen): check if the to_apply() attribute contains
253         // instructions that break LLVM vectorization.
254         return false;
255       }
256       return true;
257     }
258 
259     CHECK(instr.getDialect() ==
260           instr.getContext()->getLoadedDialect<mlir::mhlo::MhloDialect>())
261         << MlirToString(op);
262     switch (*MhloToHloOpcode(&instr)) {
263       case HloOpcode::kReduceWindow:
264       case HloOpcode::kSort:
265       case HloOpcode::kDot:
266       case HloOpcode::kSin:
267       case HloOpcode::kCos:
268       case HloOpcode::kPower:
269       case HloOpcode::kAtan2:
270         return true;
271       case HloOpcode::kConcatenate:
272         if (instr.getOperands().size() > kMaxConcatArgumentsForUnrolling) {
273           return true;
274         }
275         break;
276       case HloOpcode::kReduce:
277         if (instr.getNumResults() > 1) {
278           return true;
279         }
280         break;
281       default:
282         break;
283     }
284   }
285   return false;
286 }
287 
288 std::vector<mlir::Operation*> GetOutputOps(mlir::lmhlo::FusionOp fusion) {
289   llvm::SetVector<mlir::Operation*> ops;
290   for (mlir::Value output_value : fusion.getFusionResults()) {
291     ops.insert(output_value.getDefiningOp());
292   }
293   return std::vector<mlir::Operation*>(ops.begin(), ops.end());
294 }
295 
296 // Computes the maximum valid unroll factor for a given instruction.
297 int ComputeMaxUnrollFactor(mlir::Type type,
298                            const HloModuleConfig& hlo_module_config) {
299   int max_unroll_factor =
300       hlo_module_config.debug_options().xla_gpu_max_kernel_unroll_factor();
301 
302   // Find the largest possible power of two to unroll by.
303   // TODO(kramerb): Make this smarter.
304 
305   auto shaped_type = type.cast<mlir::ShapedType>();
306   int64_t num_elements = std::accumulate(shaped_type.getShape().begin(),
307                                          shaped_type.getShape().end(), int64{1},
308                                          std::multiplies<int64>());
309   for (int i = max_unroll_factor; i > 1; i /= 2) {
310     if (num_elements % i == 0) {
311       return i;
312     }
313   }
314 
315   // Cannot unroll.
316   return 1;
317 }
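
// A worked example of the search above (values are hypothetical): with
// xla_gpu_max_kernel_unroll_factor = 4 and a shape holding 6 elements, the
// loop tries 4 (6 % 4 != 0) and then 2 (6 % 2 == 0), so the kernel is
// unrolled by a factor of 2. An odd element count falls all the way through
// and returns 1, i.e. no unrolling.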
318 
319 // Computes the maximum valid unroll factor for a given instruction.
320 int ComputeMaxUnrollFactor(mlir::Operation* op,
321                            const HloModuleConfig& hlo_module_config) {
322   mlir::Type element_shape = [&] {
323     std::vector<mlir::Type> shapes;
324     // Detect multi-output fusion. Notice that for a reduce in the fusion that
325     // returns a tuple, we don't want to treat it as multi-output fusion. We
326     // want to pass that tuple into ComputeMaxUnrollFactor below. For an actual
327     // MOF, just pass the first element of the root tuple.
328     if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
329       std::vector<mlir::Operation*> fusion_outputs = GetOutputOps(fusion);
330       for (mlir::Value result : fusion_outputs[0]->getResults()) {
331         return result.getType();
332       }
333     } else {
334       for (mlir::Value result : GetHloOutputs(op)) {
335         return result.getType();
336       }
337     }
338     CHECK(false);
339   }();
340   return ComputeMaxUnrollFactor(element_shape, hlo_module_config);
341 }
342 
343 // Returns the llvm type for the indices used in the kernel that contains the
344 // hlo instruction. Such indices include the index for the parallel loop and
345 // the indices for the tensors accessed by the kernel. The return type is i32
346 // iff the following conditions are met:
347 //  . The launch_size of the kernel is within the range of i32.
348 //  . The sizes of all the tensors accessed within the kernel are within the
349 //    range of i32.
350 // Otherwise, the return type is i64.
351 llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo,
352                                   int64_t launch_size, llvm::IRBuilder<>* b) {
353   // Find the unnested hlo instruction for which the kernel is generated.
354   const HloInstruction* unnested_hlo = hlo;
355   const HloComputation* computation = hlo->parent();
356   if (computation->IsFusionComputation()) {
357     unnested_hlo = computation->FusionInstruction();
358   }
359 
360   auto shape_in_range = [&](const Shape& s) {
361     bool in_range = true;
362     ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape,
363                                       const ShapeIndex& /*index*/) {
364       if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) {
365         in_range = false;
366       }
367     });
368 
369     return in_range;
370   };
371 
372   llvm::Type* i64_ty = b->getInt64Ty();
373   // Check launch dimension
374   if (!IsInt32(launch_size)) {
375     return i64_ty;
376   }
377 
378   // Check the size of result tensors
379   if (!shape_in_range(unnested_hlo->shape())) {
380     return i64_ty;
381   }
382 
383   auto hlo_shape_in_range = [&](const HloInstruction* operand) -> bool {
384     return shape_in_range(operand->shape());
385   };
386 
387   // Check the size of input tensors
388   if (!absl::c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) {
389     return i64_ty;
390   }
391 
392   // Check the size of the internal result tensors
393   if (unnested_hlo->opcode() == HloOpcode::kFusion) {
394     if (!absl::c_all_of(
395             unnested_hlo->fused_instructions_computation()->instructions(),
396             hlo_shape_in_range)) {
397       return i64_ty;
398     }
399   }
400 
401   return b->getInt32Ty();
402 }
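
// For intuition (shapes below are hypothetical): a kernel whose launch size
// and tensors all stay under 2^31 elements, e.g. f32[4096,4096] (~1.7e7
// elements), gets i32 indices, which are cheaper on GPUs; a tensor such as
// f32[50000,50000] (2.5e9 elements) exceeds the i32 range, so every index in
// that kernel is widened to i64.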
403 
404 // The same as GetIndexTypeForKernel, but works with MLIR ops.
405 llvm::Type* GetIndexTypeForKernel(mlir::Operation* op, int64_t launch_size,
406                                   llvm::IRBuilder<>* b) {
407   auto shape_in_range = [&](const Shape& s) {
408     bool in_range = true;
409     ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape,
410                                       const ShapeIndex& /*index*/) {
411       if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) {
412         in_range = false;
413       }
414     });
415 
416     return in_range;
417   };
418 
419   llvm::Type* i64_ty = b->getInt64Ty();
420   // Check launch dimension
421   if (!IsInt32(launch_size)) {
422     return i64_ty;
423   }
424 
425   // Check the size of result tensors
426   for (auto result : GetHloOutputs(op)) {
427     if (!shape_in_range(GetShape(result))) {
428       return i64_ty;
429     }
430   }
431 
432   auto hlo_shape_in_range = [&](mlir::Value operand) -> bool {
433     return shape_in_range(GetShape(operand));
434   };
435 
436   // Check the size of input tensors
437   if (!absl::c_all_of(op->getOperands(), hlo_shape_in_range)) {
438     return i64_ty;
439   }
440 
441   // Check the size of the internal result tensors
442   if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
443     auto result = fusion.region().walk([&](mlir::Operation* op) {
444       for (mlir::Value result : op->getResults()) {
445         if (!hlo_shape_in_range(result)) {
446           return mlir::WalkResult::interrupt();
447         }
448       }
449       return mlir::WalkResult::advance();
450     });
451     if (result.wasInterrupted()) {
452       return i64_ty;
453     }
454   }
455 
456   return b->getInt32Ty();
457 }
458 
459 // Gets the input shape of the ROOT slices, which will be used as the kernel
460 // launch dims. The slice input fusion requires the input shapes of the ROOT
461 // slices to be the same although the (slice) output shapes can be different.
462 //
463 // Returns the input shape of the ROOT slices if all the input shapes of ROOT
464 // slices are the same and the slices are non-strided. Otherwise, returns
465 // FailedPrecondition.
466 StatusOr<Shape> GetConsistentInputShapeForRootSlices(
467     const HloComputation* fused_computation) {
468   const HloInstruction& root = *fused_computation->root_instruction();
469   if (root.opcode() == HloOpcode::kSlice) {
470     return root.operands()[0]->shape();
471   }
472 
473   CHECK_EQ(root.opcode(), HloOpcode::kTuple);
474   const Shape& first_slice_operand_shape =
475       root.operands()[0]->operands()[0]->shape();
476   for (size_t i = 1; i < root.operands().size(); ++i) {
477     const HloInstruction* slice = root.operands()[i];
478     const Shape& operand_shape = slice->operands()[0]->shape();
479     if (!ShapeUtil::EqualIgnoringElementType(first_slice_operand_shape,
480                                              operand_shape)) {
481       return FailedPrecondition(
482           "Fused slices do not have the same input shape, fused computation = "
483           "%s.",
484           root.parent()->name());
485     }
486   }
487 
488   return first_slice_operand_shape;
489 }
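
// Example (hypothetical HLO): a fusion rooted at
//   tuple(slice(p0 f32[8,100] -> f32[8,50]), slice(p0 f32[8,100] -> f32[2,100]))
// yields f32[8,100], since both slices read an operand of that shape; if the
// slices read differently shaped operands, this returns FailedPrecondition.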
490 
491 }  // namespace
492 
493 IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
494                                      IrEmitterContext* ir_emitter_context)
495     : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false) {}
496 
497 StatusOr<std::unique_ptr<IrEmitterUnnested>> IrEmitterUnnested::Create(
498     const HloModuleConfig& hlo_module_config,
499     IrEmitterContext* ir_emitter_context) {
500   return std::unique_ptr<IrEmitterUnnested>(
501       new IrEmitterUnnested(hlo_module_config, ir_emitter_context));
502 }
503 
504 llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
505     absl::string_view name, absl::Span<const BufferAllocation* const> args) {
506   // Compute the kernel name. The opcode string may contain "-" which cannot be
507   // in a PTX function name, so sanitize the name before uniquifying it.
508   string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName(
509       llvm_ir::SanitizeFunctionName(std::string(name)));
510 
511   // Create the kernel and add it to the module.
512   llvm::Module* module = ir_emitter_context_->llvm_module();
513   llvm::LLVMContext& context = module->getContext();
514   llvm::FunctionType* kernel_type = llvm::FunctionType::get(
515       /*Result=*/llvm::Type::getVoidTy(context),
516       std::vector<llvm::Type*>(args.size(), b_.getInt8PtrTy()),
517       /*isVarArg=*/false);
518   llvm::Function* kernel =
519       llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage,
520                              kernel_name.c_str(), module);
521 
522   // Add dereferenceable and alignment information to each of the kernel's
523   // parameters.
524   auto arg_it = kernel->arg_begin();
525   for (size_t arg_no = 0; arg_no < args.size(); ++arg_no) {
526     const BufferAllocation* alloc = args[arg_no];
527     llvm::Argument* fn_arg = &*arg_it;
528     ++arg_it;
529 
530     kernel->addDereferenceableAttr(arg_no + 1, alloc->size());
531 
532     const int64_t alignment = [&] {
533       if (alloc->is_entry_computation_parameter()) {
534         return kEntryParameterAlignBytes;
535       } else if (alloc->is_constant()) {
536         return kConstantBufferAlignBytes;
537       } else {
538         return kXlaAllocatedBufferAlignBytes;
539       }
540     }();
541 
542     kernel->addParamAttr(
543         arg_no,
544         llvm::Attribute::get(context, llvm::Attribute::Alignment, alignment));
545 
546     if (alloc->IsPreallocatedTempBuffer()) {
547       fn_arg->setName("temp_buf");
548     } else {
549       fn_arg->setName(StrCat("alloc", alloc->index()));
550     }
551   }
552 
553   AnnotateFunctionAsGpuKernel(module, kernel, &b_);
554 
555   // TODO(b/65380986): Investigate if adding fast math flags for generated
556   // kernels makes sense.
557 
558   // Update the insert point to the entry basic block.
559   llvm::BasicBlock* entry_bb =
560       llvm::BasicBlock::Create(context, /*Name=*/"entry", /*Parent=*/kernel);
561 
562   // Emit a "return void" at entry_bb's end, and set the insert point before
563   // that return instruction.
564   b_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb));
565 
566   return kernel;
567 }
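
// For illustration only (the name, sizes, and alignments below are
// hypothetical), a prototype built for two allocations looks roughly like:
//
//   define void @fusion_7(i8* align 128 dereferenceable(1024) %alloc0,
//                         i8* align 128 dereferenceable(4096) %temp_buf) {
//   entry:
//     ret void
//   }
//
// The kernel body emitted later is inserted just before the trailing
// `ret void`.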
568 
569 StatusOr<BufferAllocation::Slice> IrEmitterUnnested::GetAllocationSlice(
570     mlir::Value v, std::string* constant_name) {
571   return xla::gpu::GetAllocationSlice(v, ir_emitter_context_->allocations(),
572                                       constant_name);
573 }
574 
575 Status IrEmitterUnnested::EmitConstant(mlir::Operation* op) {
576   auto get_global = mlir::cast<mlir::memref::GetGlobalOp>(op);
577   auto module = get_global->getParentOfType<mlir::ModuleOp>();
578   auto global = mlir::cast<mlir::memref::GlobalOp>(
579       module.lookupSymbol(get_global.name()));
580 
581   auto literal = global.initial_value()->dyn_cast<mlir::DenseElementsAttr>();
582   TF_RET_CHECK(literal);
583 
584   const bool should_emit_initializer = literal.getType().getNumElements() <= 1;
585 
586   TF_ASSIGN_OR_RETURN(int element_bytes,
587                       GetElementTypeBytes(literal.getType().getElementType()));
588   llvm::ArrayType* global_type = llvm::ArrayType::get(
589       b_.getInt8Ty(), literal.getType().getNumElements() * element_bytes);
590 
591   GpuExecutable::ConstantInfo info;
592   llvm::Constant* initializer;
593   if (should_emit_initializer) {
594     std::vector<uint8> content;
595     TF_RETURN_IF_ERROR(CopyDenseElementsDataToXlaFormat(literal, &content));
596     initializer = llvm::ConstantDataArray::get<uint8>(
597         ir_emitter_context_->llvm_module()->getContext(), content);
598   } else {
599     TF_RETURN_IF_ERROR(
600         CopyDenseElementsDataToXlaFormat(literal, &info.content));
601     initializer = llvm::ConstantAggregateZero::get(global_type);
602   }
603 
604   // These globals will be looked up by name by GpuExecutable so we need to
605   // give them an external linkage.  Not all of their uses are visible in
606   // the LLVM IR, so we can't give them a linkage that merely preserves their
607   // names (like available_externally); we also need to ensure that they stick
608   // around even if they're "unused".
609   //
610   // We may have to be more clever here in the future if we notice that we're
611   // keeping around too many globals because of their linkage.
612   llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
613       global_type, /*isConstant=*/should_emit_initializer,
614       llvm::GlobalValue::ExternalLinkage,
615       /*Initializer=*/initializer, global.sym_name(),
616       /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
617       /*AddressSpace=*/0,
618       /*isExternallyInitialized=*/false);
619   global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes));
620   ir_emitter_context_->llvm_module()->getGlobalList().push_back(
621       global_for_const);
622 
623   info.symbol_name.assign(global.sym_name().begin(), global.sym_name().end());
624 
625   info.allocation_index =
626       global->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc").getInt();
627   ir_emitter_context_->constants().push_back(std::move(info));
628   return Status::OK();
629 }
630 
631 static ConditionalThunkConfig GetConditionalThunkConfig(
632     mlir::lmhlo::CaseOp op, std::vector<ThunkSequence> branch_thunk_sequences) {
633   ConditionalThunkConfig config;
634   config.branch_index_is_bool =
635       op.index().getType().cast<mlir::ShapedType>().getElementType().isInteger(
636           /*width=*/1);
637   config.branch_count = op.branches().size();
638   // Pass an empty ThunkInfo to the branch_thunks constructors because these
639   // SequentialThunks are logically "part of" this ConditionalThunk, and
640   // shouldn't be profiled separately from it.
641   config.branch_thunks.reserve(branch_thunk_sequences.size());
642   for (auto& branch_thunk_sequence : branch_thunk_sequences) {
643     config.branch_thunks.emplace_back(new SequentialThunk(
644         Thunk::ThunkInfo(), std::move(branch_thunk_sequence)));
645   }
646   return config;
647 }
648 
649 Status IrEmitterUnnested::EmitConditional(mlir::Operation* op) {
650   auto conditional = mlir::cast<mlir::lmhlo::CaseOp>(op);
651 
652   std::vector<ThunkSequence> branch_thunks;
653 
654   int branch_count = conditional.branches().size();
655   branch_thunks.reserve(branch_count);
656 
657   for (int j = 0; j < branch_count; ++j) {
658     mlir::Region* branch_computation = &conditional.branches()[j];
659     TF_ASSIGN_OR_RETURN(
660         auto ir_emitter,
661         IrEmitterUnnested::Create(hlo_module_config_, ir_emitter_context_));
662     TF_RETURN_IF_ERROR(ir_emitter->EmitLmhloRegion(branch_computation));
663     branch_thunks.push_back(std::move(*ir_emitter->ConsumeThunkSequence()));
664   }
665 
666   ConditionalThunkConfig config =
667       GetConditionalThunkConfig(conditional, std::move(branch_thunks));
668 
669   TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(conditional.index()));
670   AddThunkToThunkSequence(std::unique_ptr<Thunk>(
671       new ConditionalThunk(GetThunkInfo(op), std::move(config), slice)));
672   return Status::OK();
673 }
674 
675 // Input = {dynamic array (with dynamic dimension metadata at the end)}
676 // Output = {static array, dynamic_dim0, dynamic_dim1}
677 Status IrEmitterUnnested::EmitPadToStatic(mlir::Operation* op) {
678   // TODO(jurahul): Create an op to represent PadToStatic.
679   auto pad_to_static = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
680   int unroll_factor = 1;
681   std::string ir_name = mlir::GetNameFromLoc(pad_to_static.getLoc());
682 
683   const Shape& input_shape = GetShape(pad_to_static.args().front());
684   TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
685                       CalculateLaunchDimensions(
686                           input_shape, ir_emitter_context_->gpu_device_info(),
687                           {unroll_factor}));
688   std::vector<llvm_ir::IrArray> ir_arrays;
689   TF_ASSIGN_OR_RETURN(auto kernel_thunk,
690                       BuildKernelThunk(pad_to_static, GetThunkInfo(op),
691                                        &ir_arrays, launch_dimensions));
692 
693   const llvm_ir::IrArray source_array = ir_arrays[0];
694   const llvm_ir::IrArray output_array = ir_arrays[1];
695   auto output_dim_arrays =
696       absl::Span<const llvm_ir::IrArray>(ir_arrays).subspan(2);
697 
698   // pseudo code for PadToStatic on a 2d array
699   //   int* source_array = input[0];
700   //   int* dest_array = output[0];
701   llvm::Value* source_buffer = source_array.GetBasePointer();
702   llvm::Value* raw_buffer =
703       b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo());
704 
705   // TODO(jurahul): input_shape here is the static shape of the input (which has
706   // a dynamic shape in XLA). Currently, we are mapping that to a static shaped
707   // memref. When we change that to a more appropriate representation in MLIR,
708   // fix this code to correctly deduce the static shape backing the dynamically
709   // shaped memref.
710   int64_t raw_data_size = ShapeUtil::ByteSizeOf(input_shape);
711 
712   //   int* dyn_dim0_size = source_array + meta_data_offset;
713   //   int* dyn_dim1_size = source_array + meta_data_offset + sizeof(int);
714   std::vector<llvm::Value*> dynamic_dims;
715   for (int64_t i = 1; i < pad_to_static.output().size(); ++i) {
716     // The dynamic size of each dimension is attached at the end of the
717     // source array (operand(0)). We need to extract these values.
718     const Shape& dim_shape = GetShape(pad_to_static.output()[i]);
719     TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32)));
720 
721     const int64_t dim_index = i - 1;
722     llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
723         b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32));
724     llvm::Value* dyn_dim_size = b_.CreateLoad(
725         b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()),
726         "dyn_dim_size");
727     dynamic_dims.push_back(dyn_dim_size);
728   }
729 
730   // Only one thread needs to store the dynamic dimension sizes.
731   //   int thread_id = GetThreadId();
732   //   int block_id = GetBlockId();
733   //   if (thread_id == 0 && block_id == 0) {
734   //     *output[1] = *dyn_dim0_size;
735   //     *output[2] = *dyn_dim1_size;
736   //   }
737   KernelSupportLibrary{&b_}.If("is_thread_0", IsBlock0Thread0(&b_), [&] {
738     for (int64_t i = 1; i < pad_to_static.output().size(); ++i) {
739       const int64_t dim_index = i - 1;
740       llvm::Value* dest_dim_size_address =
741           output_dim_arrays[dim_index].GetBasePointer();
742       // output[i] stores dynamic_dim_(i-1)
743       b_.CreateStore(dynamic_dims[i - 1],
744                      b_.CreateBitCast(dest_dim_size_address,
745                                       b_.getInt32Ty()->getPointerTo()));
746     }
747   });
748 
749   //     int dyn_element_total = 1;
750   //     dyn_element_total *= *dyn_dim0_size;
751   //     dyn_element_total *= *dyn_dim1_size;
752   llvm::Value* dyn_element_total = llvm::ConstantInt::get(b_.getInt32Ty(), 1);
753   for (llvm::Value* dynamic_dim : dynamic_dims) {
754     dyn_element_total = b_.CreateMul(dyn_element_total, dynamic_dim,
755                                      /*Name=*/"dyn_element_total");
756   }
757 
758   //   linear_index = block_id * threads_per_block + thread_id;
759   //   if (linear_index < max_num_element) {
760   //     Index static_index =
761   //         delinearized(linear_index, static_dim0_size, static_dim1_size);
762   //     if (linear_index < dyn_element_total) {
763   //       Index dyn_index =
764   //           delinearized(linear_index, *dyn_dim0_size, *dyn_dim1_size);
765   //       dest_array[dyn_index.dim0][dyn_index.dim1] =
766   //           source_array[static_index.dim0][static_index.dim1];
767   //     }
768   //   }
769   llvm_ir::LoopEmitter::BodyEmitter body_generator =
770       [&](const llvm_ir::IrArray::Index& array_index) -> Status {
771     llvm::Value* linearIndex =
772         array_index.Linearize(input_shape.dimensions(), &b_);
773     auto if_in_dyn_bounds = llvm_ir::EmitIfThenElse(
774         b_.CreateICmpULT(linearIndex, dyn_element_total),
775         llvm_ir::IrName(ir_name, "in_dyn_bounds"), &b_, false);
776     // Set IR builder insertion point to the body of the if structure.
777     llvm_ir::SetToFirstInsertPoint(if_in_dyn_bounds.true_block, &b_);
778     llvm_ir::IrArray::Index dyn_index(linearIndex, input_shape,
779                                       absl::MakeSpan(dynamic_dims), &b_);
780     output_array.EmitWriteArrayElement(
781         dyn_index,
782         source_array.EmitReadArrayElement(array_index, &b_, /*name=*/""), &b_,
783         /*use_linear_index=*/false);
784     return Status::OK();
785   };
786 
787   const Shape& data_shape = GetShape(pad_to_static.output().front());
788   TF_RETURN_IF_ERROR(
789       ParallelLoopEmitter(body_generator, data_shape, launch_dimensions, &b_,
790                           {unroll_factor})
791           .EmitLoop(ir_name,
792                     GetIndexTypeForKernel(
793                         pad_to_static, launch_dimensions.launch_bound(), &b_)));
794   thunk_sequence_.emplace_back(std::move(kernel_thunk));
795   return Status::OK();
796 }
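
// Memory layout assumed by the metadata offsets above (bounds are
// hypothetical): for an input that is logically f32[<=2, <=4], the source
// buffer holds 2 * 4 * 4 = 32 bytes of statically padded data followed by two
// i32 values, so dyn_dim0 is read from byte offset 32 and dyn_dim1 from
// offset 36.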
797 
798 // Input = {static array, dynamic_dim0, dynamic_dim1}
799 // Output = {dynamic array (with dynamic dimension metadata at the end)}
800 Status IrEmitterUnnested::EmitSliceToDynamic(mlir::Operation* op) {
801   // TODO(jurahul): Create an op to represent SliceToDynamic.
802   auto slice_to_dynamic = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
803   int unroll_factor = 1;
804   std::string ir_name = mlir::GetNameFromLoc(slice_to_dynamic.getLoc());
805 
806   const Shape& input_shape = GetShape(slice_to_dynamic.args().front());
807   TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
808                       CalculateLaunchDimensions(
809                           input_shape, ir_emitter_context_->gpu_device_info(),
810                           {unroll_factor}));
811   std::vector<llvm_ir::IrArray> ir_arrays;
812   TF_ASSIGN_OR_RETURN(auto kernel_thunk,
813                       BuildKernelThunk(slice_to_dynamic, GetThunkInfo(op),
814                                        &ir_arrays, launch_dimensions));
815 
816   TF_RET_CHECK(slice_to_dynamic.output().size() == 1);
817   const Shape& data_shape = GetShape(slice_to_dynamic.output().front());
818 
819   // TODO(jurahul): data_shape here is the static shape of the output (which has
820   // a dynamic shape in XLA). Currently, we are mapping that to a static shaped
821   // memref. When we change that to a more appropriate representation in MLIR,
822   // fix this code to correctly deduce the static shape backing the dynamically
823   // shaped memref.
824 
825   // calculate the location where metadata needs to be inserted
826   //   int* dyn_dim0_size = dest_array + meta_data_offset;
827   //   int* dyn_dim1_size = dest_array + meta_data_offset + sizeof(int);
828   int32_t raw_data_size = ShapeUtil::ByteSizeOf(data_shape);
829 
830   // pseudo code for sliceToDynamic on a 2d array
831   //   int* source_array = input[0];
832   //   int* dest_array = output[0];
833   const llvm_ir::IrArray data_array = ir_arrays.back();
834   llvm::Value* dest_buffer = data_array.GetBasePointer();
835   llvm::Value* raw_buffer =
836       b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo());
837 
838   // Load dynamic dimensions from memory.
839   std::vector<llvm::Value*> dynamic_dims;
840   for (int64_t i = 1; i < slice_to_dynamic.args().size(); ++i) {
841     // const int64 dim_index = i - 1;
842     llvm::Value* source_buffer = ir_arrays[i].GetBasePointer();
843     llvm::LoadInst* dyn_dim_size = b_.CreateLoad(source_buffer, "dyn_dim_size");
844     dynamic_dims.push_back(dyn_dim_size);
845   }
846 
847   // Only one thread needs to store the dynamic dimension sizes.
848   //   int thread_id = GetThreadId();
849   //   int block_id = GetBlockId();
850   //   if (thread_id == 0 && block_id == 0) {
851   //     *dyn_dim0_size = *output[1];
852   //     *dyn_dim1_size = *output[2];
853   //   }
854   KernelSupportLibrary{&b_}.If("is_thread_0", IsBlock0Thread0(&b_), [&] {
855     for (int64_t i = 1; i < slice_to_dynamic.args().size(); ++i) {
856       const int64_t dim_index = i - 1;
857       llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
858           b_.getInt8Ty(), raw_buffer,
859           raw_data_size + dim_index * sizeof(int32));
860       // output[i] stores dynamic_dim_(i-1)
861       b_.CreateStore(
862           dynamic_dims[dim_index],
863           b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()));
864     }
865   });
866 
867   //     int dyn_element_total = 1;
868   //     dyn_element_total *= dyn_dim0_size;
869   //     dyn_element_total *= dyn_dim1_size;
870   llvm::Value* dyn_element_total = llvm::ConstantInt::get(b_.getInt32Ty(), 1);
871   for (llvm::Value* dynamic_dim : dynamic_dims) {
872     dyn_element_total = b_.CreateMul(dyn_element_total, dynamic_dim,
873                                      /*Name=*/"dyn_element_total");
874   }
875 
876   //   linear_index = block_id * threads_per_block + thread_id;
877   //   if (linear_index < max_num_element) {
878   //     Index static_index =
879   //         delinearized(linear_index, static_dim0_size, static_dim1_size);
880   //     if (linear_index < dyn_element_total) {
881   //       Index dyn_index =
882   //           delinearized(linear_index, *dyn_dim0_size, *dyn_dim1_size);
883   //       dest_array[static_index.dim0][static_index.dim1] =
884   //           source_array[dyn_index.dim0][dyn_index.dim1];
885   //     }
886   //   }
887   llvm_ir::LoopEmitter::BodyEmitter body_generator =
888       [&](const llvm_ir::IrArray::Index& array_index) -> Status {
889     llvm::Value* linearIndex =
890         array_index.Linearize(input_shape.dimensions(), &b_);
891     auto if_in_dyn_bounds = llvm_ir::EmitIfThenElse(
892         b_.CreateICmpULT(linearIndex, dyn_element_total),
893         llvm_ir::IrName(ir_name, "in_dyn_bounds"), &b_, false);
894     // Set IR builder insertion point to the body of the if structure.
895     llvm_ir::SetToFirstInsertPoint(if_in_dyn_bounds.true_block, &b_);
896     llvm_ir::IrArray::Index dyn_index(linearIndex, input_shape,
897                                       absl::MakeSpan(dynamic_dims), &b_);
898 
899     data_array.EmitWriteArrayElement(
900         array_index,
901         ir_arrays[0].EmitReadArrayElement(dyn_index, &b_, /*name=*/"",
902                                           /*use_linear_index=*/false),
903         &b_);
904     return Status::OK();
905   };
906 
907   TF_RETURN_IF_ERROR(
908       ParallelLoopEmitter(body_generator, data_shape, launch_dimensions, &b_,
909                           {unroll_factor})
910           .EmitLoop(ir_name, GetIndexTypeForKernel(
911                                  slice_to_dynamic,
912                                  launch_dimensions.launch_bound(), &b_)));
913   thunk_sequence_.emplace_back(std::move(kernel_thunk));
914   return Status::OK();
915 }
916 
917 Status IrEmitterUnnested::EmitConvolutionThunk(mlir::Operation* op) {
918   using mlir::dyn_cast;
919   using mlir::lmhlo_gpu::Activation;
920   using mlir::lmhlo_gpu::ConvBackwardFilterOp;
921   using mlir::lmhlo_gpu::ConvBackwardInputOp;
922   using mlir::lmhlo_gpu::ConvForwardFusedOp;
923   using mlir::lmhlo_gpu::ConvForwardFusedSideInputOp;
924   using mlir::lmhlo_gpu::ConvForwardOp;
925 
926   // The last two operands of the convolution op are the result and scratch.
927   std::vector<BufferAllocation::Slice> operand_slices;
928   int64_t num_operands = op->getNumOperands();
929   operand_slices.reserve(num_operands - 2);
930   for (mlir::Value operand : op->getOperands().drop_back(2)) {
931     TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(operand));
932     operand_slices.push_back(slice);
933   }
934 
935   mlir::Value conv_result = op->getOperand(num_operands - 2);
936   mlir::Value scratch_result = op->getOperand(num_operands - 1);
937   TF_ASSIGN_OR_RETURN(auto conv_result_slice, GetAllocationSlice(conv_result));
938   TF_ASSIGN_OR_RETURN(auto scratch_slice, GetAllocationSlice(scratch_result));
939 
940   auto apply_layout = [](const Shape& shape, mlir::ArrayAttr layout_attrib) {
941     mlir::SmallVector<int64, 4> minor_to_major = llvm::to_vector<4>(
942         llvm::map_range(layout_attrib, [](mlir::Attribute a) -> int64 {
943           return static_cast<int64>(a.cast<mlir::IntegerAttr>().getInt());
944         }));
945     return ShapeUtil::MakeShapeWithLayout(shape.element_type(),
946                                           shape.dimensions(), minor_to_major);
947   };
948 
949   GpuConvDescriptor descriptor;
950 
951   auto fill_conv_descriptor = [&](auto op) {
952     descriptor.operand0_shape = apply_layout(
953         GetShape(op->getOperand(0)), op.backend_config().operand_0_layout());
954     descriptor.operand1_shape = apply_layout(
955         GetShape(op->getOperand(1)), op.backend_config().operand_1_layout());
956     descriptor.result_shape = apply_layout(GetShape(conv_result),
957                                            op.backend_config().result_layout());
958     descriptor.dnums = ConvertConvDimensionNumbers(op.dimension_numbers());
959     descriptor.scratch_size = scratch_slice.size();
960     mlir::DenseIntElementsAttr window_strides = op.window_strides().getValue();
961     mlir::DenseIntElementsAttr padding = op.padding().getValue();
962     mlir::DenseIntElementsAttr lhs_dilation = op.lhs_dilation().getValue();
963     mlir::DenseIntElementsAttr rhs_dilation = op.rhs_dilation().getValue();
964     mlir::DenseElementsAttr window_reversal = op.window_reversal().getValue();
965     for (auto index : llvm::seq<int>(0, window_strides.getNumElements())) {
966       WindowDimension* dim = descriptor.window.add_dimensions();
967       // Window size for a convolution is the same as the kernel size.
968       // Kernel size of the convolution is operand1_shape. We need to look at
969       // the kernel spatial dimensions in the convolution dimension numbers
970       // to get the window size.
971       int kernel_dim = descriptor.dnums.kernel_spatial_dimensions(index);
972       dim->set_size(descriptor.operand0_shape.dimensions(kernel_dim));
973       dim->set_stride(window_strides.getValue<int64>(index));
974       dim->set_padding_low(padding.getValue<int64>(index));
975       dim->set_padding_high(padding.getValue<int64>(index));
976       dim->set_base_dilation(lhs_dilation.getValue<int64>(index));
977       dim->set_window_dilation(rhs_dilation.getValue<int64>(index));
978       dim->set_window_reversal(window_reversal.getValue<bool>(index));
979     }
980     descriptor.feature_group_count = op.feature_group_count();
981     descriptor.backend_config.set_algorithm(
982         op.backend_config().algorithm().getInt());
983     descriptor.backend_config.set_tensor_ops_enabled(
984         op.backend_config().tensor_ops_enabled().getValue());
985     descriptor.backend_config.set_conv_result_scale(
986         op.result_scale().convertToDouble());
987   };
988 
989   auto set_activation_mode = [&](auto op) -> Status {
990     TF_ASSIGN_OR_RETURN(stream_executor::dnn::ActivationMode activation_mode,
991                         ConvertConvActivationMode(op.activation_mode()));
992     descriptor.backend_config.set_activation_mode(
993         static_cast<int64>(activation_mode));
994     return Status::OK();
995   };
996 
997   if (auto conv = dyn_cast<ConvForwardOp>(op)) {
998     descriptor.kind = CudnnConvKind::kForward;
999     fill_conv_descriptor(conv);
1000   } else if (auto conv = dyn_cast<ConvBackwardInputOp>(op)) {
1001     descriptor.kind = CudnnConvKind::kBackwardInput;
1002     fill_conv_descriptor(conv);
1003   } else if (auto conv = dyn_cast<ConvBackwardFilterOp>(op)) {
1004     descriptor.kind = CudnnConvKind::kBackwardFilter;
1005     fill_conv_descriptor(conv);
1006   } else if (auto conv = dyn_cast<ConvForwardFusedOp>(op)) {
1007     descriptor.kind = CudnnConvKind::kForwardActivation;
1008     fill_conv_descriptor(conv);
1009     TF_RETURN_IF_ERROR(set_activation_mode(conv));
1010   } else if (auto conv = dyn_cast<ConvForwardFusedSideInputOp>(op)) {
1011     descriptor.kind = CudnnConvKind::kForwardActivation;
1012     fill_conv_descriptor(conv);
1013     TF_RETURN_IF_ERROR(set_activation_mode(conv));
1014     descriptor.backend_config.set_side_input_scale(
1015         conv.side_input_scale().convertToDouble());
1016   } else {
1017     return InternalError("Unexpected operation");
1018   }
1019   TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(descriptor, ""));
1020   AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
1021       GetThunkInfo(op), std::move(config), std::move(operand_slices),
1022       conv_result_slice, scratch_slice));
1023   return Status::OK();
1024 }
1025 
1026 Status IrEmitterUnnested::EmitGemmThunk(mlir::Operation* op) {
1027   auto make_thunk_for_gemm =
1028       [&](auto op, absl::optional<BufferAllocation::Slice> bias = absl::nullopt,
1029           absl::optional<double> gemm_bias_beta = absl::nullopt,
1030           bool implements_whole_instruction =
1031               true) -> StatusOr<std::unique_ptr<Thunk>> {
1032     TF_ASSIGN_OR_RETURN(auto lhs, GetAllocationSlice(op.lhs()));
1033     TF_ASSIGN_OR_RETURN(auto rhs, GetAllocationSlice(op.rhs()));
1034     TF_ASSIGN_OR_RETURN(auto output, GetAllocationSlice(op.output()));
1035     std::vector<BufferAllocation::Slice> inputs = {lhs, rhs};
1036     if (bias.has_value()) {
1037       inputs.push_back(bias.value());
1038     }
1039 
1040     if (IsBefThunkEnabled() && op.lhs_stride() && op.rhs_stride()) {
1041       // TODO(loreno): TFRT support for zero-strided gemm calls
1042       return CreateBefThunk(GetThunkInfo(op), op, inputs,
1043                             std::vector<BufferAllocation::Slice>{output});
1044     }
1045 
1046     GpuGemmConfig config;
1047     GemmBackendConfig& backend = config.backend_config;
1048     config.output_shape = GetShape(op.output());
1049     config.lhs_shape = GetShape(op.lhs());
1050     config.rhs_shape = GetShape(op.rhs());
1051     backend.Clear();
1052     if (op.algorithm()) {
1053       backend.set_selected_algorithm(*op.algorithm());
1054     }
1055     backend.set_alpha_real(op.alpha_real().convertToDouble());
1056     backend.set_alpha_imag(op.alpha_imag().convertToDouble());
1057     backend.set_batch_size(op.batch_size());
1058     if (gemm_bias_beta.has_value()) {
1059       backend.set_beta(gemm_bias_beta.value());
1060     }
1061     backend.set_lhs_stride(op.lhs_stride());
1062     backend.set_rhs_stride(op.rhs_stride());
1063 
1064     auto& dims = *backend.mutable_dot_dimension_numbers();
1065     auto mlir_dims = op.dot_dimension_numbers();
1066 
1067     auto fill_dims = [](mlir::DenseElementsAttr mlir_dim, auto* config_attrs) {
1068       for (llvm::APInt e : mlir_dim.getIntValues())
1069         config_attrs->Add(e.getSExtValue());
1070     };
1071     fill_dims(mlir_dims.lhs_batching_dimensions(),
1072               dims.mutable_lhs_batch_dimensions());
1073     fill_dims(mlir_dims.rhs_batching_dimensions(),
1074               dims.mutable_rhs_batch_dimensions());
1075     fill_dims(mlir_dims.lhs_contracting_dimensions(),
1076               dims.mutable_lhs_contracting_dimensions());
1077     fill_dims(mlir_dims.rhs_contracting_dimensions(),
1078               dims.mutable_rhs_contracting_dimensions());
1079 
1080     return std::unique_ptr<Thunk>(
1081         new GemmThunk(GetThunkInfo(op), std::move(config), lhs, rhs, output,
1082                       implements_whole_instruction));
1083   };
1084 
1085   TF_ASSIGN_OR_RETURN(auto thunk, [&]() -> StatusOr<std::unique_ptr<Thunk>> {
1086     if (auto gemm = mlir::dyn_cast<mlir::lmhlo_gpu::GEMMOp>(op)) {
1087       return make_thunk_for_gemm(gemm);
1088     }
1089 
1090     if (auto gemm = mlir::dyn_cast<mlir::lmhlo_gpu::GEMM_BiasOp>(op)) {
1091       double gemm_bias_beta = gemm.beta().convertToDouble();
1092       TF_ASSIGN_OR_RETURN(auto bias, GetAllocationSlice(gemm.bias()));
1093       TF_ASSIGN_OR_RETURN(auto output, GetAllocationSlice(gemm.output()));
1094 
1095       // The bias is passed inside the output buffer. If those buffers are
1096       // shared we can just use it, otherwise copy the bias values into the
1097       // output buffer first.
1098       if (bias == output) {
1099         return make_thunk_for_gemm(gemm, bias, gemm_bias_beta);
1100       }
1101 
1102       ThunkSequence thunks;
1103       thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
1104           Thunk::ThunkInfo(),
1105           /*source_buffer=*/bias,
1106           /*destination_buffer=*/output,
1107           /*mem_size=*/
1108           ShapeUtil::ByteSizeOf(GetShape(gemm.output()))));
1109       TF_ASSIGN_OR_RETURN(
1110           auto thunk,
1111           make_thunk_for_gemm(gemm, bias, gemm_bias_beta,
1112                               /*implements_whole_instruction=*/false));
1113       thunks.push_back(std::move(thunk));
1114       return std::unique_ptr<Thunk>(
1115           new SequentialThunk(GetThunkInfo(op), std::move(thunks)));
1116     }
1117 
1118     return tensorflow::errors::Internal("Unexpected op.");
1119   }());
1120 
1121   AddThunkToThunkSequence(std::move(thunk));
1122   return Status::OK();
1123 }
1124 
1125 namespace {
1126 // An MLIR value and its name as defined in the ODS spec.
1127 struct NamedValue {
1128   mlir::Value value;
1129   absl::string_view name;
1130 };
1131 
1132 // Verifies that the given batch norm is well formed for thunk emission. This
1133 // requires that all statistics operands (mean, stddev, etc.) have F32 type
1134 // and that all the non-statistics operands match in shape, element type, and
1135 // layout (which maps to them having the same memref type).
1136 Status VerifyBatchNormForThunkEmission(
1137     mlir::ArrayRef<NamedValue> statistics_operands,
1138     mlir::ArrayRef<NamedValue> other_operands) {
1139   for (const NamedValue& v : statistics_operands) {
1140     // Note: MLIR verification will ensure that the operands of the batchnorm
1141     // LHLO are valid memref types.
1142     if (!v.value.getType().cast<mlir::MemRefType>().getElementType().isF32()) {
1143       return Unimplemented("Operand %s of batch norm should have F32 type",
1144                            v.name);
1145     }
1146   }
1147   if (other_operands.empty()) {
1148     return Status::OK();
1149   }
1150 
1151   mlir::Type first_type = other_operands.front().value.getType();
1152   absl::string_view first_name = other_operands.front().name;
1153 
1154   for (const NamedValue& v : other_operands.drop_front(1)) {
1155     if (v.value.getType() != first_type) {
1156       return Unimplemented("%s and %s for batch norm should have same types",
1157                            v.name, first_name);
1158     }
1159   }
1160 
1161   return Status::OK();
1162 }
1163 
1164 // Determines whether to enable the row-optimized codegen.  When we have a
1165 // fusion with only pointwise operations, scalar broadcasting, and row
1166 // broadcasting, we can trigger a kernel that vectorizes the row loads.
1167 // This speeds up the kernel, in particular on A100.
1168 // Returns a pair<bool, int>. The bool indicates whether we should try to
1169 // enable row vectorization.  The int is the number of inputs with the
1170 // highest rank.
1171 std::pair<bool, int> RowVectorizationEnabled(mlir::lmhlo::FusionOp fusion) {
1172   const auto is_row_major = [](mlir::Value value) {
1173     // This has only been tested when the inputs are row-major, so only
1174     // enable that case. It might also work if only the
1175     // inner dimension is contiguous.
1176     return LayoutUtil::IsMonotonicWithDim0Major(GetShape(value).layout());
1177   };
1178   bool row_vectorized =
1179       fusion.getFusionResults().size() == 1 &&  // Not tested with MOF.
1180       absl::c_all_of(GetHloOperands(fusion), is_row_major) &&
1181       absl::c_all_of(GetHloOutputs(fusion), is_row_major);
1182 
1183   // Check that the operations in the fusion are supported.  Each
1184   // supported operation (or category) must be manually vetted as XLA
1185   // only unrolls and relies on LLVM to vectorize. But this is brittle.
1186   // Currently tested and supported operations:
1187   // Elementwise, scalar and row broadcasting.
1188   //
1189   // We also detect at the same time if there is a row broadcasting
1190   // operation.
1191   bool some_row_broadcasting = false;
1192   auto out_rank =
1193       fusion.getFusionResults()[0].getType().cast<mlir::ShapedType>().getRank();
1194   int num_big_inputs = 0;
1195   for (mlir::Operation& op : fusion.region().front()) {
1196     if (auto load = mlir::dyn_cast<mlir::memref::TensorLoadOp>(op)) {
1197       auto rank = load.getResult().getType().cast<mlir::ShapedType>().getRank();
1198       num_big_inputs += static_cast<int>(rank == out_rank);
1199       continue;
1200     } else if (mlir::isa<mlir::memref::TensorStoreOp, mlir::lmhlo::TerminatorOp,
1201                          mlir::mhlo::ReturnOp, mlir::mhlo::ConstOp,
1202                          mlir::lmhlo::ConstOp>(op)) {
1203       continue;
1204     }
1205     HloOpcode opcode = *MhloToHloOpcode(&op);
1206     if (HloInstruction::IsOpElementwise(opcode)) {
1207       continue;
1208     }
1209 
1210     if (auto broadcast = mlir::dyn_cast<mlir::mhlo::BroadcastInDimOp>(op)) {
1211       if (broadcast.broadcast_dimensions().size() == 0) {
1212         continue;
1213       }
1214       std::vector<int64> broadcast_dimensions;
1215       for (const llvm::APInt& int_value : broadcast.broadcast_dimensions()) {
1216         broadcast_dimensions.push_back(int_value.getSExtValue());
1217       }
1218 
1219       auto rank = GetShape(broadcast.getResult()).rank();
1220       if (broadcast_dimensions.size() == 1 &&
1221           broadcast_dimensions.back() == (rank - 1)) {
1222         some_row_broadcasting = true;
1223         continue;
1224       }
1225     }
1226     VLOG(2) << "Row vectorization not enabled due to this op: "
1227             << MlirToString(&op);
1228     return std::make_pair(false, 0);
1229   }
1230   // Trigger only when there is row broadcasting.
1231   return std::make_pair(row_vectorized && some_row_broadcasting,
1232                         num_big_inputs);
1233 }
1234 }  // namespace
1235 
1236 Status IrEmitterUnnested::EmitBatchNormThunk(mlir::Operation* op) {
1237   auto get_batch_norm_config = [](auto op, mlir::Value output) {
1238     CudnnBatchNormConfig config;
1239     config.output_shape = GetShape(output);
1240     config.output_type = config.output_shape.element_type();
1241     config.epsilon = op.epsilon().convertToFloat();
1242     config.feature_index = op.feature_index();
1243     return config;
1244   };
1245 
1246   // The statistics operands for batch norm operations need to have F32 type,
1247   // and the rest of the operands need to match in shape, layout, and element
1248   // type.
1249   if (auto bn_train =
1250           mlir::dyn_cast<mlir::lmhlo_gpu::BatchNormTrainingOp>(op)) {
1251     TF_RETURN_IF_ERROR(VerifyBatchNormForThunkEmission(
1252         /*statistics_operands=*/
1253         {{bn_train.scale(), "scale"},
1254          {bn_train.offset(), "offset"},
1255          {bn_train.batch_mean(), "batch_mean"},
1256          {bn_train.batch_stddev(), "batch_stddev"}},
1257         /*other_operands=*/
1258         {{bn_train.operand(), "operand"}, {bn_train.output(), "output"}}));
1259     TF_ASSIGN_OR_RETURN(auto operand, GetAllocationSlice(bn_train.operand()));
1260     TF_ASSIGN_OR_RETURN(auto scale, GetAllocationSlice(bn_train.scale()));
1261     TF_ASSIGN_OR_RETURN(auto offset, GetAllocationSlice(bn_train.offset()));
1262 
1263     // BatchNormTraining returns a tuple of three elements: data, calculated
1264     // mean, and calculated 1/sqrt(variance + epsilon).
1265     TF_ASSIGN_OR_RETURN(auto output_data,
1266                         GetAllocationSlice(bn_train.output()));
1267     TF_ASSIGN_OR_RETURN(auto output_mean,
1268                         GetAllocationSlice(bn_train.batch_mean()));
1269     TF_ASSIGN_OR_RETURN(auto output_inv_stddev,
1270                         GetAllocationSlice(bn_train.batch_stddev()));
1271 
1272     AddThunkToThunkSequence(
1273         absl::make_unique<CudnnBatchNormForwardTrainingThunk>(
1274             GetThunkInfo(op),
1275             /*config=*/get_batch_norm_config(bn_train, bn_train.output()),
1276             /*operand=*/operand,
1277             /*scale=*/scale,
1278             /*offset=*/offset,
1279             /*output_data=*/output_data,
1280             /*output_mean=*/output_mean,
1281             /*output_inv_stddev=*/output_inv_stddev));
1282     return Status::OK();
1283   }
1284 
1285   if (auto bn_grad = mlir::dyn_cast<mlir::lmhlo_gpu::BatchNormGradOp>(op)) {
1286     TF_RETURN_IF_ERROR(VerifyBatchNormForThunkEmission(
1287         /*statistics_operands=*/
1288         {{bn_grad.scale(), "scale"},
1289          {bn_grad.mean(), "mean"},
1290          {bn_grad.stddev(), "stddev"},
1291          {bn_grad.grad_scale(), "grad_scale"},
1292          {bn_grad.grad_offset(), "grad_offset"}},
1293         /*other_operands=*/
1294         {{bn_grad.operand(), "operand"},
1295          {bn_grad.grad_output(), "grad_output"},
1296          {bn_grad.grad_operand(), "grad_operand"}}));
1297 
1298     TF_ASSIGN_OR_RETURN(auto operand, GetAllocationSlice(bn_grad.operand()));
1299     TF_ASSIGN_OR_RETURN(auto scale, GetAllocationSlice(bn_grad.scale()));
1300     TF_ASSIGN_OR_RETURN(auto mean, GetAllocationSlice(bn_grad.mean()));
1301     TF_ASSIGN_OR_RETURN(auto inv_stddev, GetAllocationSlice(bn_grad.stddev()));
1302     TF_ASSIGN_OR_RETURN(auto grad_output,
1303                         GetAllocationSlice(bn_grad.grad_output()));
1304 
1305     // BatchNormGrad returns a tuple of three elements: grad_data, grad_scale,
1306     // grad_offset.
1307     TF_ASSIGN_OR_RETURN(auto output_grad_data,
1308                         GetAllocationSlice(bn_grad.grad_operand()));
1309     TF_ASSIGN_OR_RETURN(auto output_grad_scale,
1310                         GetAllocationSlice(bn_grad.grad_scale()));
1311     TF_ASSIGN_OR_RETURN(auto output_grad_offset,
1312                         GetAllocationSlice(bn_grad.grad_offset()));
1313 
1314     CudnnBatchNormConfig config;
1315     config.output_shape = GetShape(bn_grad.grad_output());
1316     config.output_type = config.output_shape.element_type();
1317     config.epsilon = bn_grad.epsilon().convertToFloat();
1318     config.feature_index = bn_grad.feature_index();
1319 
1320     AddThunkToThunkSequence(absl::make_unique<CudnnBatchNormBackwardThunk>(
1321         GetThunkInfo(op),
1322         /*config=*/get_batch_norm_config(bn_grad, bn_grad.grad_output()),
1323         /*operand=*/operand,
1324         /*scale=*/scale,
1325         /*mean=*/mean,
1326         /*inv_stddev=*/inv_stddev,
1327         /*grad_output=*/grad_output,
1328         /*output_grad_data=*/output_grad_data,
1329         /*output_grad_scale=*/output_grad_scale,
1330         /*output_grad_offset=*/output_grad_offset));
1331     return Status::OK();
1332   }
1333 
1334   if (auto bn_inference =
1335           mlir::dyn_cast<mlir::lmhlo_gpu::BatchNormInferenceOp>(op)) {
1336     TF_RETURN_IF_ERROR(
1337         VerifyBatchNormForThunkEmission(/*statistics_operands=*/
1338                                         {{bn_inference.scale(), "scale"},
1339                                          {bn_inference.offset(), "offset"},
1340                                          {bn_inference.mean(), "mean"},
1341                                          {bn_inference.stddev(), "stddev"}},
1342                                         /*other_operands=*/
1343                                         {{bn_inference.operand(), "operand"},
1344                                          {bn_inference.output(), "output"}}));
1345 
1346     TF_ASSIGN_OR_RETURN(auto operand,
1347                         GetAllocationSlice(bn_inference.operand()));
1348     TF_ASSIGN_OR_RETURN(auto scale, GetAllocationSlice(bn_inference.scale()));
1349     TF_ASSIGN_OR_RETURN(auto offset, GetAllocationSlice(bn_inference.offset()));
1350     TF_ASSIGN_OR_RETURN(auto mean, GetAllocationSlice(bn_inference.mean()));
1351     TF_ASSIGN_OR_RETURN(auto variance,
1352                         GetAllocationSlice(bn_inference.stddev()));
1353     TF_ASSIGN_OR_RETURN(auto output, GetAllocationSlice(bn_inference.output()));
1354 
1355     AddThunkToThunkSequence(absl::make_unique<
1356                             CudnnBatchNormForwardInferenceThunk>(
1357         GetThunkInfo(op),
1358         /*config=*/get_batch_norm_config(bn_inference, bn_inference.output()),
1359         /*operand=*/operand,
1360         /*scale=*/scale,
1361         /*offset=*/offset,
1362         /*mean=*/mean,
1363         /*variance=*/variance,
1364         /*output=*/output));
1365     return Status::OK();
1366   }
1367 
1368   return Unimplemented("Unsupported batch norm operation");
1369 }
1370 
1371 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1372 Status IrEmitterUnnested::EmitCholeskyThunk(mlir::Operation* op) {
1373   auto cholesky_op = mlir::cast<mlir::lmhlo_gpu::CholeskyOp>(op);
1374 
1375   const Shape shape = GetShape(cholesky_op.input());
1376   int ndim = shape.dimensions_size();
1377   CHECK_GE(ndim, 2);
1378   int64_t n = shape.dimensions(ndim - 1);
1379 
1380   const auto& dims = shape.dimensions();
1381   int64_t batch_size =
1382       std::accumulate(dims.begin(), dims.end() - 2, int64{1},
1383                       [](int64_t a, int64_t b) { return a * b; });
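  // Worked example with a hypothetical shape: for an input of f32[3, 4, 8, 8]
  // this gives n = 8 and batch_size = 3 * 4 = 12 independent 8x8 matrices.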
1384 
1385   TF_ASSIGN_OR_RETURN(auto operand_buffer,
1386                       GetAllocationSlice(cholesky_op.input()));
1387   TF_ASSIGN_OR_RETURN(auto a_buffer, GetAllocationSlice(cholesky_op.output()));
1388   TF_ASSIGN_OR_RETURN(auto workspace_buffer,
1389                       GetAllocationSlice(cholesky_op.scratch()));
1390   TF_ASSIGN_OR_RETURN(auto info_buffer, GetAllocationSlice(cholesky_op.info()));
1391 
1392   ThunkSequence thunks;
1393 
1394   if (operand_buffer != a_buffer) {
1395     thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
1396         GetThunkInfo(op),
1397         /*source_address=*/operand_buffer,
1398         /*destination_buffer=*/a_buffer,
1399         /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
1400   }
1401 
1402   CholeskyOptions options;
1403   options.set_lower(cholesky_op.is_lower());
1404   thunks.push_back(absl::make_unique<CholeskyThunk>(
1405       GetThunkInfo(op), options, a_buffer, workspace_buffer, info_buffer,
1406       shape.element_type(), batch_size, n));
1407 
1408   // Elide the sequential thunk if there's no copy.
1409   if (thunks.size() == 1) {
1410     AddThunkToThunkSequence(std::move(thunks[0]));
1411   } else {
1412     AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
1413         GetThunkInfo(op), std::move(thunks)));
1414   }
1415 
1416   return Status::OK();
1417 }
1418 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1419 
1420 Status IrEmitterUnnested::EmitCustomCallThunk(mlir::Operation* op) {
1421   auto custom_call = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
1422   const std::string call_target_name = custom_call.call_target_name().str();
1423 
1424   void* call_target = CustomCallTargetRegistry::Global()->Lookup(
1425       call_target_name, std::string(platform_name()));
1426   if (!call_target) {
1427     return Unimplemented(
1428         "No registered implementation for custom call to \"%s\"",
1429         call_target_name);
1430   }
1431 
1432   std::vector<CustomCallThunk::OptionalSlice> operands;
1433   std::vector<CustomCallThunk::OptionalSlice> results;
1434 
1435   if (custom_call.target_arg_mapping()) {
1436     auto values_to_slices_with_token_holes =
1437         [&](mlir::ValueRange operands, mlir::ArrayAttr op_to_target_mapping,
1438             mlir::IntegerAttr num_target)
1439         -> StatusOr<std::vector<CustomCallThunk::OptionalSlice>> {
1440       std::vector<CustomCallThunk::OptionalSlice> slices(num_target.getInt());
1441       for (auto index_and_value_it :
1442            llvm::zip(op_to_target_mapping, operands)) {
1443         mlir::Attribute index_attr = std::get<0>(index_and_value_it);
1444         mlir::Value value = std::get<1>(index_and_value_it);
1445         int64_t index = index_attr.cast<mlir::IntegerAttr>().getInt();
1446         TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
1447                             GetAllocationSlice(value));
1448         slices[index] = slice;
1449       }
1450       return slices;
1451     };
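    // Sketch of the mapping with hypothetical values: if num_target is 3 and
    // op_to_target_mapping is [0, 2], then slices[0] and slices[2] receive
    // buffer slices while slices[1] stays empty -- a "hole" left for a token
    // argument that has no backing buffer.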
1452 
1453     mlir::lmhlo::CustomCallTargetArgMapping target_mapping =
1454         *custom_call.target_arg_mapping();
1455     TF_ASSIGN_OR_RETURN(
1456         operands, values_to_slices_with_token_holes(
1457                       custom_call.args(), target_mapping.args_to_target_args(),
1458                       target_mapping.num_args()));
1459     TF_ASSIGN_OR_RETURN(results, values_to_slices_with_token_holes(
1460                                      custom_call.output(),
1461                                      target_mapping.results_to_target_results(),
1462                                      target_mapping.num_results()));
1463   } else {
1464     auto values_to_slices = [&](mlir::ValueRange values)
1465         -> StatusOr<std::vector<CustomCallThunk::OptionalSlice>> {
1466       std::vector<CustomCallThunk::OptionalSlice> slices;
1467       for (mlir::Value value : values) {
1468         TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
1469                             GetAllocationSlice(value));
1470         slices.push_back(slice);
1471       }
1472       return slices;
1473     };
1474 
1475     TF_ASSIGN_OR_RETURN(operands, values_to_slices(custom_call.args()));
1476     TF_ASSIGN_OR_RETURN(results, values_to_slices(custom_call.output()));
1477   }
1478 
1479   CustomCallThunk::CustomCallTarget custom_call_target;
1480 
1481   // For information about this calling convention, see
1482   // xla/g3doc/custom_call.md.
1483   switch (custom_call.api_version()) {
1484     case mlir::mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL:
1485       using original_call_type =
1486           void (*)(CustomCallThunk::Stream /*stream*/, void** /*buffers*/,
1487                    const char* /*opaque*/, size_t /*opaque_len*/);
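      // A hypothetical user-side target for this API version would be
      // declared roughly as
      //   void MyTarget(CUstream stream, void** buffers,
      //                 const char* opaque, size_t opaque_len);
      // and registered via XLA_REGISTER_CUSTOM_CALL_TARGET(MyTarget, "CUDA").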
1488       custom_call_target = [call_target](CustomCallThunk::Stream stream,
1489                                          void** buffers, const char* opaque,
1490                                          size_t opaque_len,
1491                                          XlaCustomCallStatus*) {
1492         auto typed_call_target =
1493             reinterpret_cast<original_call_type>(call_target);
1494         typed_call_target(stream, buffers, opaque, opaque_len);
1495       };
1496       break;
1497     case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING:
1498       using status_returning_call_type =
1499           void (*)(CustomCallThunk::Stream /*stream*/, void** /*buffers*/,
1500                    const char* /*opaque*/, size_t /*opaque_len*/,
1501                    XlaCustomCallStatus* /*status*/);
1502       custom_call_target =
1503           reinterpret_cast<status_returning_call_type>(call_target);
1504       break;
1505     default:
1506       return InternalError("Unknown custom-call API version enum value: %d",
1507                            custom_call.api_version());
1508   }
1509 
1510   AddThunkToThunkSequence(absl::make_unique<CustomCallThunk>(
1511       GetThunkInfo(op), std::move(custom_call_target), std::move(operands),
1512       std::move(results), custom_call.backend_config().str()));
1513   return Status::OK();
1514 }
1515 
1516 Status IrEmitterUnnested::EmitFftThunk(mlir::Operation* op) {
1517   auto fft_op = mlir::cast<mlir::lmhlo::FftOp>(op);
1518   const Shape operand_shape = GetShape(fft_op.operand());
1519   const Shape output_shape = GetShape(fft_op.output());
1520   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand_shape.layout()));
1521   TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(output_shape.layout()));
1522 
1523   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice arg_slice,
1524                       GetAllocationSlice(fft_op.operand()));
1525   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dest_slice,
1526                       GetAllocationSlice(fft_op.output()));
1527   TF_ASSIGN_OR_RETURN(xla::FftType fft_type, ConvertFftType(fft_op.fft_type()));
1528   auto fft_length_values = fft_op.fft_length().getValues<int64>();
1529   std::vector<int64> fft_length(fft_length_values.begin(),
1530                                 fft_length_values.end());
1531   AddThunkToThunkSequence(
1532       absl::make_unique<FftThunk>(GetThunkInfo(op), fft_type, fft_length,
1533                                   /*input_buffer=*/arg_slice,
1534                                   /*output_buffer=*/dest_slice,
1535                                   /*input_shape=*/operand_shape,
1536                                   /*output_shape=*/output_shape));
1537   return Status::OK();
1538 }
1539 
1540 Status IrEmitterUnnested::EmitTriangularSolve(mlir::Operation* op) {
1541   auto triangular_solve_op = mlir::cast<mlir::lmhlo::TriangularSolveOp>(op);
1542   auto has_fortran_layout = [](mlir::DenseIntElementsAttr layout_attr) {
1543     int64_t n = layout_attr.getNumElements();
1544     return layout_attr.getValue<int64_t>({0}) == n - 2 &&
1545            layout_attr.getValue<int64_t>({1}) == n - 1;
1546   };
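  // Example for the hypothetical rank-3 case: a minor-to-major layout
  // attribute of [1, 2, 0] passes the check, since dimension n-2 = 1 is
  // minor-most and dimension n-1 = 2 comes next, i.e. the trailing matrix is
  // stored column-major (Fortran layout).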
1547   TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_a()));
1548   TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_b()));
1549   TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_output()));
1550 
1551   const Shape b_shape = GetShape(triangular_solve_op.b());
1552 
1553   const Shape output_shape = GetShape(triangular_solve_op.output());
1554 
1555   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice a_slice,
1556                       GetAllocationSlice(triangular_solve_op.a()));
1557   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice b_slice,
1558                       GetAllocationSlice(triangular_solve_op.b()));
1559   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1560                       GetAllocationSlice(triangular_solve_op.output()));
1561   TF_ASSIGN_OR_RETURN(TriangularSolveOptions_Transpose transpose_a,
1562                       ConvertTranspose(triangular_solve_op.transpose_a()));
1563 
1564   ThunkSequence thunks;
1565 
1566   // Triangular solve is in-place on 'b', so copy 'b' to the output if they
1567   // aren't the same buffer.
1568   if (b_slice != output_slice) {
1569     thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
1570         Thunk::ThunkInfo(),
1571         /*source_address=*/b_slice,
1572         /*destination_buffer=*/output_slice,
1573         /*mem_size=*/ShapeUtil::ByteSizeOf(b_shape)));
1574   }
1575 
1576   int64_t m = b_shape.dimensions(b_shape.rank() - 2);
1577   int64_t n = b_shape.dimensions(b_shape.rank() - 1);
1578   int64_t batch_size = std::accumulate(
1579       b_shape.dimensions().begin(), b_shape.dimensions().end() - 2, int64{1},
1580       [](int64_t a, int64_t b) { return a * b; });
1581   int64_t elem_size =
1582       ShapeUtil::ByteSizeOfPrimitiveType(output_shape.element_type());
1583   int64_t a_batch_stride =
1584       triangular_solve_op.left_side() ? m * m * elem_size : n * n * elem_size;
1585   int64_t b_batch_stride = m * n * elem_size;
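  // Worked example with hypothetical sizes: for f32 (elem_size = 4) with
  // left_side, m = 16 and n = 4 give a_batch_stride = 16 * 16 * 4 = 1024
  // bytes and b_batch_stride = 16 * 4 * 4 = 256 bytes.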
1586   TriangularSolveOptions options;
1587   options.set_left_side(triangular_solve_op.left_side());
1588   options.set_lower(triangular_solve_op.lower());
1589   options.set_unit_diagonal(triangular_solve_op.unit_diagonal());
1590   options.set_transpose_a(transpose_a);
1591   thunks.push_back(absl::make_unique<TriangularSolveThunk>(
1592       GetThunkInfo(op), options,
1593       /*a_input_buffer=*/a_slice,
1594       /*b_input_buffer=*/output_slice, output_shape.element_type(), batch_size,
1595       m, n, a_batch_stride, b_batch_stride));
1596 
1597   // Elide the sequential thunk if there's no copy.
1598   if (thunks.size() == 1) {
1599     AddThunkToThunkSequence(std::move(thunks[0]));
1600   } else {
1601     AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
1602         GetThunkInfo(op), std::move(thunks)));
1603   }
1604   return Status::OK();
1605 }
1606 
1607 // Convert the following form of fusion region:
1608 //   fusion() {
1609 //     %0 = tensor_load %external_memref0
1610 //     %1 = tensor_load %external_memref1
1611 //     ...
1612 //     tensor_store %ret, %external_memref2
1613 //   }
1614 // to
1615 //   fusion(%external_memref0, %external_memref1) (^bb(%0, %1) {
1616 //     ...
1617 //     mhlo.return %ret
1618 //   })
1619 //
1620 // So that it's suitable for MHLO -> XLA HLO conversion.
1621 // This function won't be needed once ElementalIrEmitter migrates to take MHLO
1622 // instead.
1623 static Status ProcessFusionForConversion(mlir::Region* region,
1624                                          std::vector<Shape>* operand_shapes,
1625                                          std::vector<Shape>* output_shapes) {
1626   std::vector<mlir::memref::TensorLoadOp> loads;
1627   std::vector<mlir::memref::TensorStoreOp> stores;
1628 
1629   region->walk([&](mlir::memref::TensorLoadOp load) {
1630     if (load.memref().getParentRegion() != region) {
1631       loads.push_back(load);
1632     }
1633   });
1634 
1635   region->walk([&](mlir::memref::TensorStoreOp store) {
1636     if (store.memref().getParentRegion() != region) {
1637       stores.push_back(store);
1638     }
1639   });
1640 
1641   for (auto load : loads) {
1642     auto arg = region->addArgument(load.getType());
1643     load.replaceAllUsesWith(arg);
1644     Shape shape = GetShape(load.getResult());
1645     operand_shapes->push_back(std::move(shape));
1646     load.erase();
1647   }
1648 
1649   std::vector<mlir::Value> returned_values;
1650   for (auto store : stores) {
1651     Shape shape = GetShape(store.memref());
1652     output_shapes->push_back(shape);
1653 
1654     returned_values.push_back(store.tensor());
1655     store.erase();
1656   }
1657 
1658   region->back().back().erase();
1659   auto b = mlir::OpBuilder::atBlockEnd(&region->back());
1660   auto loc = returned_values[0].getLoc();
1661   b.create<mlir::mhlo::ReturnOp>(loc, returned_values);
1662   return Status::OK();
1663 }
1664 
1665 // TODO(timshen): update the comment once the HandleFusion code path is deleted.
1666 //
1667 // This is migrated from IrEmitter::HandleFusion() with IrEmitterUnnested as the
1668 // subclass. The logic is de-virtualized and less scattered.
1669 Status IrEmitterUnnested::EmitLoopFusion(mlir::Operation* op) {
1670   auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(op);
1671   MlirEmitterContext context;
1672   context.SetOperation(fusion);
1673 
1674   TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation,
1675                       GetOrCreateSubComputationFromRegion(&fusion.region(),
1676                                                           /*is_fusion=*/true));
1677 
1678   int unroll_factor;
1679   if (!MayPreventVectorization(fusion)) {
1680     unroll_factor = ComputeMaxUnrollFactor(fusion, hlo_module_config_);
1681   } else {
1682     unroll_factor = 1;
1683   }
1684 
1685   bool row_vectorized;
1686   int num_big_inputs;
1687   std::tie(row_vectorized, num_big_inputs) = RowVectorizationEnabled(fusion);
1688   bool few_waves = [fusion, row_vectorized, num_big_inputs]() mutable {
1689     for (mlir::Operation& op : fusion.region().front()) {
1690       if (mlir::isa<mlir::memref::TensorLoadOp, mlir::memref::TensorStoreOp,
1691                     mlir::lmhlo::TerminatorOp, mlir::mhlo::ReturnOp,
1692                     mlir::mhlo::ConstOp>(op)) {
1693         continue;
1694       }
1695       HloOpcode opcode = *MhloToHloOpcode(&op);
1696       if (HloInstruction::IsOpElementwise(opcode)) {
1697         continue;
1698       }
1699       if (auto broadcast = mlir::dyn_cast<mlir::mhlo::BroadcastInDimOp>(op)) {
1700         if (broadcast.broadcast_dimensions().empty() ||
1701             // More than 2 big inputs cause a speed regression.
1702             (row_vectorized && num_big_inputs <= 2)) {
1703           continue;
1704         }
1705       }
1706       VLOG(2) << "few_waves not enabled due to: " << MlirToString(&op);
1707       return false;
1708     }
1709     return true;
1710   }();
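  // For instance (hypothetical fusion): a purely elementwise body such as
  // multiply(p0, p1) keeps few_waves enabled; a non-scalar broadcast disables
  // it unless rows are vectorized and at most two inputs have the full rank.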
1711 
1712   Shape element_shape = context.output_shapes[0];
1713   LaunchDimensionsConfig launch_config{unroll_factor, few_waves,
1714                                        row_vectorized};
1715   // Check that the shape is supported.
1716   if (launch_config.row_vectorized &&
1717       ThreadsPerBlockRowVectorized(element_shape,
1718                                    ir_emitter_context_->gpu_device_info(),
1719                                    launch_config) <= 0) {
1720     VLOG(2) << "Cancelling row_vectorization as the shape isn't supported.";
1721     launch_config.row_vectorized = false;
1722     launch_config.few_waves = false;
1723   }
1724 
1725   TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
1726                       CalculateLaunchDimensions(
1727                           element_shape, ir_emitter_context_->gpu_device_info(),
1728                           launch_config));
1729 
1730   std::vector<llvm_ir::IrArray> ir_arrays;
1731   Thunk* kernel_thunk;
1732   {
1733     TF_ASSIGN_OR_RETURN(std::unique_ptr<KernelThunk> kernel_thunk_ptr,
1734                         BuildKernelThunk(fusion, GetThunkInfo(op), &ir_arrays,
1735                                          launch_dimensions));
1736     kernel_thunk = kernel_thunk_ptr.get();
1737     thunk_sequence_.emplace_back(std::move(kernel_thunk_ptr));
1738   }
1739 
1740   auto operand_arrays =
1741       absl::MakeSpan(ir_arrays).subspan(0, context.operand_shapes.size());
1742   auto output_element_arrays = absl::MakeSpan(ir_arrays).subspan(
1743       context.operand_shapes.size(), context.output_shapes.size());
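  // For example, for a hypothetical fusion with two operands and one output,
  // ir_arrays is laid out as [operand0, operand1, output0]; operand_arrays
  // covers the first two entries and output_element_arrays the last one.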
1744 
1745   GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
1746                                           GetNestedComputer());
1747   FusedIrEmitter fused_emitter(&elemental_emitter);
1748 
1749   for (int i = 0; i < context.operand_shapes.size(); i++) {
1750     auto* builder = &b_;
1751     auto ir_array = operand_arrays[i];
1752     fused_emitter.BindGenerator(
1753         fused_computation->parameter_instruction(i),
1754         [builder, ir_array](llvm_ir::IrArray::Index index) {
1755           return ir_array.EmitReadArrayElement(index, builder);
1756         });
1757   }
1758   TF_ASSIGN_OR_RETURN(
1759       auto element_generator,
1760       fused_emitter.GetGenerator(fused_computation->root_instruction()));
1761 
1762   llvm::Type* index_type =
1763       GetIndexTypeForKernel(fusion, launch_dimensions.launch_bound(), &b_);
1764 
1765   if (context.output_shapes.size() > 1) {
1766     // For multioutput fusion, we need to emit each operand and the root.
1767     TF_RETURN_IF_ERROR(
1768         ParallelLoopEmitter(element_generator, output_element_arrays,
1769                             launch_dimensions, &b_, launch_config)
1770             .EmitLoop(context.name, index_type));
1771   } else {
1772     TF_RETURN_IF_ERROR(
1773         ParallelLoopEmitter(element_generator, output_element_arrays[0],
1774                             launch_dimensions, &b_, launch_config)
1775             .EmitLoop(context.name, index_type));
1776   }
1777 
1778   b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
1779   return Status::OK();
1780 }
1781 
1782 Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) {
1783   auto fusion_op = mlir::cast<mlir::lmhlo::FusionOp>(op);
1784   const bool is_single_instruction = IsSingleInstructionFusion(fusion_op);
1785 
1786   // Infer the layout of fusion internal nodes.
1787   const FusionLayoutAnalysis layout_analysis(fusion_op);
1788 
1789   auto fusion_results = fusion_op.getFusionResults();
1790   TF_RET_CHECK(!fusion_results.empty());
1791   if (fusion_results.size() > 1) {
1792     // In the case of a root tuple, it can be either a reduce or a slice
1793     // input fusion.
1794     if (IsInputFusibleSlices(op, /*verify_no_strides=*/true)) {
1795       // The emitter doesn't support all cases. If it's not supported, fall
1796       // back to ElementalIrEmitter.
1797       auto status = EmitInputFusibleNonStridedSlices(op);
1798       if (status.code() == tensorflow::error::FAILED_PRECONDITION) {
1799         return EmitLoopFusion(op);
1800       }
1801       return status;
1802     }
1803 
1804     const bool is_parallel_reduce =
1805         absl::c_any_of(fusion_results, [&layout_analysis](mlir::Value result) {
1806           mlir::Operation* maybe_reduce = result.getDefiningOp();
1807           return maybe_reduce->getNumResults() == 1 &&
1808                  IsReductionFromOrToContiguousDimensions(maybe_reduce,
1809                                                          layout_analysis);
1810         });
1811 
1812     if (is_parallel_reduce) {
1813       return EmitUnnestedReduction(op, layout_analysis);
1814     }
1815   }
1816 
1817   mlir::Operation* fusion_root = fusion_results[0].getDefiningOp();
1818   if (mlir::isa<mlir::mhlo::ScatterOp>(fusion_root)) {
1819     TF_ASSIGN_OR_RETURN(
1820         const HloComputation* fused_computation,
1821         GetOrCreateSubComputationFromRegion(&fusion_op.region(),
1822                                             /*is_fusion=*/true));
1823     auto* root = fused_computation->root_instruction();
1824 
1825     ThunkSequence thunks;
1826     // The initialization from 'operand' is using different loop bounds, so
1827     // emit it in a separate kernel. Treat it like a loop fusion, writing to
1828     // the output buffer.
1829     {
1830       auto unroll_factor =
1831           ComputeMaxUnrollFactor(fusion_op, hlo_module_config_);
1832       const Shape& element_shape = root->shape();
1833       TF_ASSIGN_OR_RETURN(
1834           LaunchDimensions launch_dimensions,
1835           CalculateLaunchDimensions(element_shape,
1836                                     ir_emitter_context_->gpu_device_info(),
1837                                     {unroll_factor, /*few_waves=*/false}));
1838 
1839       std::vector<llvm_ir::IrArray> ir_arrays;
1840       TF_ASSIGN_OR_RETURN(auto operand_thunk,
1841                           BuildKernelThunk(op, Thunk::ThunkInfo(), &ir_arrays,
1842                                            launch_dimensions));
1843       thunks.push_back(std::move(operand_thunk));
1844 
1845       GpuElementalIrEmitter operand_elemental_emitter(
1846           hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
1847           GetNestedComputer());
1848       FusedIrEmitter operand_fused_emitter(&operand_elemental_emitter);
1849       for (int i = 0; i < fused_computation->num_parameters(); i++) {
1850         auto fused_operand = fused_computation->parameter_instruction(i);
1851         operand_fused_emitter.BindGenerator(
1852             fused_operand, [this, &ir_arrays, i,
1853                             fused_operand](llvm_ir::IrArray::Index index) {
1854               return ir_arrays[i].EmitReadArrayElement(index, &b_,
1855                                                        fused_operand->name());
1856             });
1857       }
1858       TF_ASSIGN_OR_RETURN(auto generator,
1859                           operand_fused_emitter.GetGenerator(root->operand(0)));
1860 
1861       TF_RETURN_IF_ERROR(
1862           ParallelLoopEmitter(generator, ir_arrays.back(), launch_dimensions,
1863                               &b_, {unroll_factor})
1864               .EmitLoop(IrName(mlir::GetNameFromLoc(fusion_op.getLoc())),
1865                         GetIndexTypeForKernel(
1866                             fusion_op, launch_dimensions.launch_bound(), &b_)));
1867     }
1868 
1869     // Now build the actual scatter, reading and writing to the freshly
1870     // filled output buffer.
1871     {
1872       const Shape& updates_shape = root->operand(2)->shape();
1873       TF_ASSIGN_OR_RETURN(
1874           LaunchDimensions launch_dimensions,
1875           CalculateLaunchDimensions(updates_shape,
1876                                     ir_emitter_context_->gpu_device_info()));
1877       std::vector<llvm_ir::IrArray> ir_arrays;
1878       TF_ASSIGN_OR_RETURN(auto scatter_thunk,
1879                           BuildKernelThunk(op, Thunk::ThunkInfo(), &ir_arrays,
1880                                            launch_dimensions));
1881       thunks.push_back(std::move(scatter_thunk));
1882       // Spin up a new fused emitter for the scatter kernel and emit it.
1883       GpuElementalIrEmitter scatter_elemental_emitter(
1884           hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
1885           GetNestedComputer());
1886       FusedIrEmitter scatter_fused_emitter(&scatter_elemental_emitter);
1887       for (int i = 0; i < fused_computation->num_parameters(); i++) {
1888         auto fused_operand = fused_computation->parameter_instruction(i);
1889         scatter_fused_emitter.BindGenerator(
1890             fused_operand, [this, &ir_arrays, i,
1891                             fused_operand](llvm_ir::IrArray::Index index) {
1892               return ir_arrays[i].EmitReadArrayElement(index, &b_,
1893                                                        fused_operand->name());
1894             });
1895       }
1896 
1897       TF_ASSIGN_OR_RETURN(const auto dim_numbers,
1898                           mlir::LhloDialectEmitter::GetScatterDimensionNumbers(
1899                               root, fusion_op.getContext()));
1900 
1901       ScatterDescriptor desc;
1902       desc.name = IrName(root);
1903       desc.operand_shape = root->operand(0)->shape();
1904       desc.scatter_indices_shape = root->operand(1)->shape();
1905       desc.updates_shape = updates_shape;
1906       desc.dim_numbers = dim_numbers;
1907       desc.unique_indices = root->unique_indices();
1908       desc.update_computation = root->called_computations()[0];
1909       desc.output = ir_arrays.back();
1910       TF_ASSIGN_OR_RETURN(desc.scatter_indices_gen,
1911                           scatter_fused_emitter.GetGenerator(root->operand(1)));
1912       TF_ASSIGN_OR_RETURN(desc.updates_gen,
1913                           scatter_fused_emitter.GetGenerator(root->operand(2)));
1914       desc.get_index_type = [&](int64_t launch_size) {
1915         return GetIndexTypeForKernel(root, launch_size, &b_);
1916       };
1917 
1918       TF_RETURN_IF_ERROR(
1919           EmitScatter(desc, thunks.back().get(), launch_dimensions));
1920     }
1921     AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
1922         GetThunkInfo(op), std::move(thunks)));
1923     return Status::OK();
1924   }
1925 
1926   // HandleFusion specializes reduction from a multi-dimensional array to
1927   // a 1D array. The specialized version requires an initializer thunk that
1928   // initializes the output array to the initial value of the reduce.
1929   if (mlir::isa<mlir::mhlo::ReduceOp>(fusion_root) &&
1930       fusion_root->getNumResults() == 1 &&
1931       IsReductionFromOrToContiguousDimensions(fusion_root, layout_analysis)) {
1932     return EmitUnnestedReduction(op, layout_analysis);
1933   }
1934 
1935   if (!is_single_instruction &&
1936       CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
1937           fusion_op, ir_emitter_context_->allocations())) {
1938     // Fusion node with dynamic-update-slice as the root where the op's input
1939     // (i.e. array to update) shares the same slice as its output.  In this case
1940     // we have a special algorithm that modifies the output in place without
1941     // touching the un-updated elements.
1942     CHECK_EQ(1, GetHloOutputs(op).size());
1943 
1944     TF_ASSIGN_OR_RETURN(
1945         const HloComputation* fused_computation,
1946         GetOrCreateSubComputationFromRegion(&fusion_op.region(),
1947                                             /*is_fusion=*/true));
1948 
1949     // Shape of the dynamic-update-slice's "update" operand.
1950     Shape update_shape =
1951         fused_computation->root_instruction()->operand(1)->shape();
1952 
1953     TF_ASSIGN_OR_RETURN(
1954         LaunchDimensions launch_dimensions,
1955         CalculateLaunchDimensions(update_shape,
1956                                   ir_emitter_context_->gpu_device_info()));
1957 
1958     // Set up kernel thunk and fused ir emitter.
1959     std::vector<llvm_ir::IrArray> ir_arrays;
1960     TF_ASSIGN_OR_RETURN(auto fusion_thunk,
1961                         BuildKernelThunk(fusion_op, GetThunkInfo(op),
1962                                          &ir_arrays, launch_dimensions));
1963     AddThunkToThunkSequence(std::move(fusion_thunk));
1964 
1965     GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
1966                                             ir_emitter_context_->llvm_module(),
1967                                             &b_, GetNestedComputer());
1968 
1969     FusedIrEmitter fused_emitter(&elemental_emitter);
1970 
1971     for (int i = 0; i < fused_computation->num_parameters(); i++) {
1972       auto fused_operand = fused_computation->parameter_instruction(i);
1973       fused_emitter.BindGenerator(
1974           fused_operand, [this, &ir_arrays, i,
1975                           fused_operand](const llvm_ir::IrArray::Index& index) {
1976             return ir_arrays[i].EmitReadArrayElement(index, &b_,
1977                                                      fused_operand->name());
1978           });
1979     }
1980 
1981     // Array to write into.  Because this is an in-place operation, this is the
1982     // same as operand 0's array.
1983     const IrArray& output_array = ir_arrays.back();
1984 
1985     return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace(
1986         fused_computation, output_array, &fused_emitter, launch_dimensions,
1987         &b_);
1988   }
1989 
1990   if (auto copy = mlir::dyn_cast<mlir::mhlo::CopyOp>(fusion_root)) {
1991     if (IsSingleInstructionFusion(fusion_op)) {
1992       auto operands = GetHloOperands(fusion_op);
1993       auto outputs = GetHloOutputs(fusion_op);
1994       TF_RET_CHECK(operands.size() == 1);
1995       TF_RET_CHECK(outputs.size() == 1);
1996 
1997       auto operand_shape = GetShape(operands[0]);
1998       auto output_shape = GetShape(outputs[0]);
1999 
2000       CHECK(ShapeUtil::Compatible(operand_shape, output_shape));
2001       auto maybe_slice = GetAllocationSlice(operands[0]);
2002       if (LayoutUtil::Equal(operand_shape.layout(), output_shape.layout()) &&
2003           maybe_slice.ok()) {
2004         // Copy the operand into the output if it's not the same buffer already.
2005         auto operand_buffer = *maybe_slice;
2006         auto destination_buffer = *GetAllocationSlice(outputs[0]);
2007         if (operand_buffer != destination_buffer) {
2008           AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
2009               GetThunkInfo(op),
2010               /*source_address=*/operand_buffer,
2011               /*destination_buffer=*/destination_buffer,
2012               /*mem_size=*/
2013               ByteSizeOf(operand_shape)));
2014         }
2015         return Status::OK();
2016       }
2017     }
2018   }
2019 
2020   TF_ASSIGN_OR_RETURN(const bool matched_021, CheckAndEmitHloWithTile021(op));
2021   if (matched_021) {
2022     return Status::OK();
2023   }
2024 
2025   return EmitLoopFusion(op);
2026 }
2027 
2028 Status IrEmitterUnnested::EmitExtraOutputsForReduce(
2029     absl::Span<const llvm_ir::IrArray> result_ir_arrays,
2030     const IrArray::Index& index, bool use_linear_index,
2031     absl::Span<const std::pair<llvm_ir::ElementGenerator, int>>
2032         extra_output_gens) {
2033   // Compute all extra output values before writing them. This avoids
2034   // overwriting aliased input/output buffers before all reads occurred.
2035   absl::InlinedVector<llvm::Value*, 8> extra_output_ir_values;
2036   for (int i = 0; i < extra_output_gens.size(); ++i) {
2037     TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
2038                         extra_output_gens[i].first(index));
2039     extra_output_ir_values.push_back(extra_output_ir_value);
2040   }
2041   for (int i = 0; i < extra_output_gens.size(); ++i) {
2042     result_ir_arrays[extra_output_gens[i].second].EmitWriteArrayElement(
2043         index, extra_output_ir_values[i], &b_, use_linear_index);
2044   }
2045   return Status::OK();
2046 }
2047 
2048 Status IrEmitterUnnested::AssertNonDeterminismIsOkay(const string& op_name) {
2049   if (hlo_module_config_.debug_options().xla_gpu_deterministic_ops()) {
2050     return Unimplemented(
2051         "HLO instruction %s does not have a deterministic implementation, "
2052         "but run-to-run determinism is required by "
2053         "--xla_gpu_deterministic_ops.",
2054         op_name);
2055   }
2056   return Status::OK();
2057 }
2058 
2059 Status IrEmitterUnnested::EmitSelectAndScatter(mlir::Operation* op) {
2060   auto select_and_scatter_op = mlir::cast<mlir::lmhlo::SelectAndScatterOp>(op);
2061 
2062   const Shape source_shape = GetShape(select_and_scatter_op.source());
2063   const Shape operand_shape = GetShape(select_and_scatter_op.operand());
2064   const int64_t rank = operand_shape.rank();
2065 
2066   CHECK_EQ(rank, source_shape.rank());
2067   if (select_and_scatter_op.window_dimensions()) {
2068     CHECK_EQ(rank, select_and_scatter_op.window_dimensions()->size());
2069   }
2070 
2071   TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay(
2072       mlir::GetNameFromLoc(select_and_scatter_op.getLoc())));
2073 
2074   std::string name = mlir::GetNameFromLoc(select_and_scatter_op.getLoc());
2075 
2076   // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk
2077   // consisting of two thunks, an initializer KernelThunk that initializes
2078   // the output and another KernelThunk that accumulates the scattered
2079   // elements.
2080   ThunkSequence thunks;
2081   thunks.emplace_back();
2082   TF_ASSIGN_OR_RETURN(thunks.back(), BuildInitializerThunk(
2083                                          op, select_and_scatter_op.init_value(),
2084                                          select_and_scatter_op.out()));
2085 
2086   TF_ASSIGN_OR_RETURN(
2087       LaunchDimensions launch_dimensions,
2088       CalculateLaunchDimensions(source_shape,
2089                                 ir_emitter_context_->gpu_device_info()));
2090   std::vector<llvm_ir::IrArray> ir_arrays;
2091   thunks.emplace_back();
2092   // Init value is not needed in IR emission.
2093   TF_ASSIGN_OR_RETURN(
2094       thunks.back(),
2095       BuildKernelThunk(
2096           select_and_scatter_op,
2097           {select_and_scatter_op.operand(), select_and_scatter_op.source(),
2098            select_and_scatter_op.out()},
2099           Thunk::ThunkInfo(), &ir_arrays, launch_dimensions));
2100 
2101   CHECK_EQ(ir_arrays.size(), 3);
2102   const IrArray& operand_array = ir_arrays[0];
2103   const IrArray& source_array = ir_arrays[1];
2104   const IrArray& out_array = ir_arrays[2];
2105 
2106   auto select_and_scatter_thunk =
2107       absl::make_unique<SequentialThunk>(GetThunkInfo(op), std::move(thunks));
2108 
2109   llvm::Type* index_type = GetIndexTypeForKernel(
2110       select_and_scatter_op, launch_dimensions.launch_bound(), &b_);
2111   auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
2112     return llvm::ConstantInt::get(index_type, c);
2113   };
2114 
2115   // kSelectAndScatter is implemented as two kernel launches: the first launch
2116   // initializes the output array to the given initial value,
2117   // and the second accumulates the "source" matrix to the
2118   // selected elements in the output array. The first launch is already
2119   // implemented by the initializer thunk generated earlier, so this function
2120   // only needs to take care of the select-and-scatter part.
2121   //
2122   // Pseudo code for select-and-scatter:
2123   //
2124   // for (coordinates S in the source):  # This loop is parallel.
2125   //   initialized_flag = false
2126   //   for (coordinates W in the window):
2127   //     I = S * stride + W - pad_low
2128   //     if I within bounds of operand:
2129   //       if !(initialized_flag and select(selected_value, operand(I))):
2130   //         selected_value = operand(I)
2131   //         selected_index = I
2132   //         initialized_flag = true
2133   //   output(selected_index) = scatter(output(selected_index), source(S))
2134   auto loop_body_emitter = [&](const IrArray::Index& source_index) -> Status {
2135     // Allocate space to keep the currently selected value, its index, and a
2136     // boolean flag if the value is initialized. The initialized_flag is set
2137     // false.
2138     llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
2139         llvm_ir::PrimitiveTypeToIrType(operand_shape.element_type(),
2140                                        ir_emitter_context_->llvm_module()),
2141         "selected_value_address", &b_);
2142 
2143     llvm::Value* selected_index_address =
2144         llvm_ir::EmitAllocaAtFunctionEntryWithCount(
2145             index_type, index_typed_constant(rank), "selected_index_address",
2146             &b_);
2147 
2148     llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
2149         b_.getInt1Ty(), "initialized_flag_address", &b_);
2150     Store(b_.getInt1(false), initialized_flag_address);
2151 
2152     // Create the inner loop to iterate over the window.
2153     llvm_ir::ForLoopNest window_loops(absl::StrCat(name, "inner"), &b_,
2154                                       index_type);
2155 
2156     DimensionVector window_size;
2157     mlir::DenseIntElementsAttr window_dimensions =
2158         select_and_scatter_op.window_dimensions().getValue();
2159     for (const auto& dim : window_dimensions) {
2160       window_size.push_back(dim.getSExtValue());
2161       CHECK_GT(dim.getSExtValue(), 0);
2162     }
2163 
2164     const IrArray::Index window_index = window_loops.AddLoopsForShape(
2165         ShapeUtil::MakeShape(operand_shape.element_type(), window_size),
2166         "window");
2167     llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
2168                                    &b_);
2169 
2170     // Compute the operand index to visit and evaluate whether the operand
2171     // index is within bounds. The unsigned comparison includes
2172     // checking whether the operand index >= 0.
2173     std::vector<llvm::Value*> operand_multi_index(source_index.size());
2174     llvm::Value* in_bounds_condition = b_.getInt1(true);
2175 
2176     auto strides = *select_and_scatter_op.window_strides();
2177     auto paddings = *select_and_scatter_op.padding();
2178 
2179     for (auto stride_and_padding :
2180          llvm::enumerate(llvm::zip(strides, paddings))) {
2181       const int i = stride_and_padding.index();
2182       int64_t stride = std::get<0>(stride_and_padding.value()).getSExtValue();
2183       int64_t padding = std::get<1>(stride_and_padding.value()).getSExtValue();
2184 
2185       llvm::Value* strided_index =
2186           NSWMul(source_index[i], index_typed_constant(stride));
2187       operand_multi_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]),
2188                                       index_typed_constant(padding));
2189       llvm::Value* index_condition = ICmpULT(
2190           operand_multi_index[i],
2191           index_typed_constant(ShapeUtil::GetDimension(operand_shape, i)));
2192       in_bounds_condition = And(in_bounds_condition, index_condition);
2193     }
2194 
2195     // Only need to do something if the operand index is within the bounds.
2196     // First check if the initialized_flag is set.
2197     llvm_ir::LlvmIfData if_in_bounds =
2198         llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
2199     llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
2200     llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
2201         Load(initialized_flag_address), "initialized", &b_);
2202 
2203     // If the initialized_flag is false, initialize the selected value and index
2204     // with the currently visiting operand.
2205     llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &b_);
2206     const auto save_operand_index = [&](const IrArray::Index& operand_index) {
2207       for (int64_t i = 0; i < rank; ++i) {
2208         llvm::Value* selected_index_address_slot =
2209             InBoundsGEP(selected_index_address, {b_.getInt32(i)});
2210         Store(operand_index[i], selected_index_address_slot);
2211       }
2212     };
2213     IrArray::Index operand_index(operand_multi_index, operand_shape,
2214                                  index_type);
2215     llvm::Value* operand_data =
2216         operand_array.EmitReadArrayElement(operand_index, &b_);
2217     Store(operand_data, selected_value_address);
2218     save_operand_index(operand_index);
2219     Store(b_.getInt1(true), initialized_flag_address);
2220 
2221     // If the initialized_flag is true, call the `select` function to
2222     // potentially update the selected value and index with the currently
2223     // visiting operand.
2224     llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &b_);
2225     llvm::Value* operand_address =
2226         operand_array.EmitArrayElementAddress(operand_index, &b_);
2227     llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
2228         llvm_ir::PrimitiveTypeToIrType(PRED,
2229                                        ir_emitter_context_->llvm_module()),
2230         "select_return_buffer", &b_);
2231 
2232     TF_ASSIGN_OR_RETURN(
2233         const HloComputation* select_computation,
2234         GetOrCreateSubComputationFromRegion(&select_and_scatter_op.select(),
2235                                             /*is_fusion=*/false));
2236 
2237     TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
2238         *select_computation, {selected_value_address, operand_address},
2239         select_return_buffer));
2240     llvm::Value* result = Load(select_return_buffer);
2241 
2242     // If the 'select' function returns false, update the selected value and the
2243     // index to the currently visiting operand.
2244     llvm::Value* cond = ICmpNE(
2245         result,
2246         llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(
2247                                    PRED, ir_emitter_context_->llvm_module()),
2248                                0),
2249         "boolean_predicate");
2250     llvm_ir::LlvmIfData if_select_lhs =
2251         llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
2252     llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
2253     Store(Load(operand_address), selected_value_address);
2254     save_operand_index(operand_index);
2255 
2256     // After iterating over the window elements, scatter the source element to
2257     // the selected index of the output. The value we store at the output
2258     // location is computed by calling the `scatter` function with the source
2259     // value and the current output value.
2260     llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
2261                                    &b_);
2262     std::vector<llvm::Value*> selected_multi_index;
2263     for (int64_t i = 0; i < rank; ++i) {
2264       llvm::Value* selected_index_address_slot =
2265           InBoundsGEP(selected_index_address, {b_.getInt32(i)});
2266       selected_multi_index.push_back(Load(selected_index_address_slot));
2267     }
2268     const Shape output_shape = GetShape(select_and_scatter_op.out());
2269     llvm::Value* source_value_address =
2270         source_array.EmitArrayElementAddress(source_index, &b_);
2271     IrArray::Index selected_index(selected_multi_index, output_shape,
2272                                   operand_index.GetType());
2273     llvm::Value* output_value_address =
2274         out_array.EmitArrayElementAddress(selected_index, &b_);
2275 
2276     TF_ASSIGN_OR_RETURN(
2277         const HloComputation* scatter_computation,
2278         GetOrCreateSubComputationFromRegion(&select_and_scatter_op.scatter(),
2279                                             /*is_fusion=*/false));
2280 
2281     return EmitAtomicOperationForNestedComputation(
2282         *scatter_computation, output_value_address, source_value_address);
2283   };
2284 
2285   AddThunkToThunkSequence(std::move(select_and_scatter_thunk));
2286   return ParallelLoopEmitter(loop_body_emitter, source_shape, launch_dimensions,
2287                              &b_)
2288       .EmitLoop(name, index_type);
2289 }
2290 
2291 Status IrEmitterUnnested::EmitWhile(mlir::Operation* op) {
2292   auto while_op = mlir::cast<mlir::lmhlo::WhileOp>(op);
2293 
2294   auto cond_result = GetHloOutputs(while_op);
2295   TF_RET_CHECK(cond_result.size() == 1);
2296   TF_RET_CHECK(cond_result[0]
2297                    .getType()
2298                    .cast<mlir::ShapedType>()
2299                    .getElementType()
2300                    .isInteger(/*width=*/1))
2301       << "While condition computation must return bool";
2302 
2303   //  Build ForThunk for conformant while loops, otherwise build WhileThunk.
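  // For example, a loop whose trip count the compiler has already proven
  // (say, 10 iterations) carries a trip_count attribute and is lowered to a
  // ForThunk; otherwise a WhileThunk re-evaluates the condition computation
  // on every iteration.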
2304   if (while_op.trip_count()) {
2305     TF_ASSIGN_OR_RETURN(auto thunk, BuildForThunk(while_op, GetThunkInfo(op),
2306                                                   *while_op.trip_count()));
2307     AddThunkToThunkSequence(std::move(thunk));
2308   } else {
2309     TF_ASSIGN_OR_RETURN(auto thunk,
2310                         BuildWhileThunk(while_op, GetThunkInfo(op)));
2311     AddThunkToThunkSequence(std::move(thunk));
2312   }
2313   return Status::OK();
2314 }
2315 
2316 Status IrEmitterUnnested::EmitRngGetAndUpdateState(mlir::Operation* op) {
2317   auto rng_op = mlir::dyn_cast<mlir::lmhlo::RngGetAndUpdateStateOp>(op);
2318 
2319   // Emit a kernel to increment the global state for Philox RNG algorithm.
2320   std::vector<llvm_ir::IrArray> ir_arrays;
2321   TF_ASSIGN_OR_RETURN(auto kernel_thunk,
2322                       BuildKernelThunk(rng_op, rng_op.state(), GetThunkInfo(op),
2323                                        &ir_arrays, LaunchDimensions()));
2324   AddThunkToThunkSequence(std::move(kernel_thunk));
2325 
2326   llvm::Value* old_state =
2327       llvm_ir::RngGetAndUpdateState(rng_op.delta(), module_, &b_);
2328 
2329   const Shape shape = GetShape(rng_op.state());
2330 
2331   llvm::Value* output_address = ir_arrays[0].EmitArrayElementAddress(
2332       llvm_ir::IrArray::Index(
2333           /*linear=*/b_.getInt64(0), shape, &b_),
2334       &b_, "rng_state_address");
2335   output_address = BitCast(
2336       output_address, llvm::PointerType::get(
2337                           old_state->getType(),
2338                           output_address->getType()->getPointerAddressSpace()));
2339   Store(old_state, output_address);
2340 
2341   return Status::OK();
2342 }
2343 
2344 Status IrEmitterUnnested::EmitScatter(mlir::Operation* op) {
2345   ThunkSequence thunks;
2346 
2347   auto scatter_op = mlir::cast<mlir::lmhlo::ScatterOp>(op);
2348 
2349   if (!scatter_op.unique_indices()) {
2350     TF_RETURN_IF_ERROR(
2351         AssertNonDeterminismIsOkay(mlir::GetNameFromLoc(scatter_op.getLoc())));
2352   }
2353 
2354   TF_ASSIGN_OR_RETURN(auto operand_buffer,
2355                       GetAllocationSlice(scatter_op.operand()));
2356   TF_ASSIGN_OR_RETURN(auto output_buffer,
2357                       GetAllocationSlice(scatter_op.output()));
2358 
2359   // Copy the operand into the output if it's not the same buffer already.
2360   if (operand_buffer != output_buffer) {
2361     thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
2362         Thunk::ThunkInfo(),
2363         /*source_address=*/operand_buffer,
2364         /*destination_buffer=*/output_buffer,
2365         /*mem_size=*/
2366         ShapeUtil::ByteSizeOf(GetShape(scatter_op.output()))));
2367   }
2368 
2369   const Shape& data_shape = GetShape(scatter_op.updates());
2370   TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
2371                       CalculateLaunchDimensions(
2372                           data_shape, ir_emitter_context_->gpu_device_info()));
2373 
2374   // Create kernel thunk for all operands except the first one (`operand`). The
2375   // code generated for scatter below assumes that the input operand is already
2376   // copied into the output, so it is not used during codegen.
2377   std::vector<llvm_ir::IrArray> ir_arrays;
2378   thunks.emplace_back();
2379   TF_ASSIGN_OR_RETURN(
2380       thunks.back(),
2381       BuildKernelThunk(scatter_op, scatter_op.getOperands().drop_front(),
2382                        GetThunkInfo(op), &ir_arrays, launch_dimensions));
2383 
2384   CHECK_EQ(ir_arrays.size(), 3);
2385   const IrArray& scatter_indices = ir_arrays[0];
2386   const IrArray& updates = ir_arrays[1];
2387   const IrArray& output = ir_arrays[2];
2388 
2389   auto get_index_type = [&](int64_t launch_size) {
2390     return GetIndexTypeForKernel(scatter_op, launch_size, &b_);
2391   };
2392 
2393   TF_RETURN_IF_ERROR(EmitScatter(
2394       thunks.back().get(), scatter_op, launch_dimensions, output,
2395       /*scatter_indices_gen=*/
2396       [&](const IrArray::Index& index) {
2397         return scatter_indices.EmitReadArrayElement(index, &b_,
2398                                                     "scatter_index");
2399       },
2400       /*updates_gen=*/
2401       [&](const IrArray::Index& index) {
2402         return updates.EmitReadArrayElement(index, &b_, "update");
2403       },
2404       /*get_index_type=*/
2405       get_index_type));
2406 
2407   // Elide the sequential thunk if there's no copy.
2408   if (thunks.size() == 1) {
2409     AddThunkToThunkSequence(std::move(thunks[0]));
2410   } else {
2411     AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
2412         GetThunkInfo(op), std::move(thunks)));
2413   }
2414 
2415   return Status::OK();
2416 }
2417 
2418 Status IrEmitterUnnested::EmitScatter(
2419     Thunk* thunk, mlir::lmhlo::ScatterOp scatter,
2420     const LaunchDimensions& launch_dimensions, const llvm_ir::IrArray& output,
2421     const llvm_ir::ElementGenerator& scatter_indices_gen,
2422     const llvm_ir::ElementGenerator& updates_gen,
2423     std::function<llvm::Type*(int64_t)> get_index_type) {
2424   const Shape operand_shape = GetShape(scatter.operand());
2425   CHECK(ShapeUtil::Equal(GetShape(scatter.output()), operand_shape));
2426 
2427   TF_ASSIGN_OR_RETURN(
2428       const HloComputation* update_computation,
2429       GetOrCreateSubComputationFromRegion(&scatter.update_computation(),
2430                                           /*is_fusion=*/false));
2431 
2432   ScatterDescriptor desc;
2433   desc.name = mlir::GetNameFromLoc(scatter.getLoc());
2434   desc.operand_shape = operand_shape;
2435   desc.scatter_indices_shape = GetShape(scatter.scatter_indices());
2436   desc.updates_shape = GetShape(scatter.updates());
2437   desc.dim_numbers = scatter.scatter_dimension_numbers();
2438   desc.unique_indices = scatter.unique_indices();
2439   desc.update_computation = update_computation;
2440   desc.output = output;
2441   desc.scatter_indices_gen = scatter_indices_gen;
2442   desc.updates_gen = updates_gen;
2443   desc.get_index_type = get_index_type;
2444   return EmitScatter(desc, thunk, launch_dimensions);
2445 }
2446 
2447 Status IrEmitterUnnested::EmitScatter(
2448     const ScatterDescriptor& desc, Thunk* thunk,
2449     const LaunchDimensions& launch_dimensions) {
2450   if (!desc.unique_indices) {
2451     TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay(desc.name));
2452   }
2453   auto loop_body_emitter = [&](const IrArray::Index& index) -> Status {
2454     std::vector<llvm::Value*> raw_window_multidim;
2455     std::vector<llvm::Value*> input_scatter_multidim;
2456     std::vector<int64> raw_window_bounds;
2457 
2458     // Partition the index into window indices and scatter indices.
2459     for (int64_t i = 0, e = index.size(); i != e; ++i) {
2460       // For window indices, also remember the window size; this comes in handy
2461       // later.
2462       if (BinarySearchDenseElementsAttr(desc.dim_numbers.update_window_dims(),
2463                                         i)) {
2464         raw_window_multidim.push_back(index[i]);
2465         raw_window_bounds.push_back(desc.updates_shape.dimensions(i));
2466       } else {
2467         input_scatter_multidim.push_back(index[i]);
2468       }
2469     }
2470     DCHECK_EQ(raw_window_multidim.size(),
2471               desc.dim_numbers.update_window_dims().size());
2472 
2473     // Apply inserted_window_dims to the window dimensions.
2474     int64_t raw_window_multidim_idx = 0;
2475     std::vector<llvm::Value*> input_window_multidim;
2476     std::vector<int64> input_window_bounds;
2477 
2478     for (int64_t i = 0, e = desc.operand_shape.rank(); i != e; ++i) {
2479       if (BinarySearchDenseElementsAttr(desc.dim_numbers.inserted_window_dims(),
2480                                         i)) {
2481         input_window_bounds.push_back(1);  // Trivial dimension.
2482         input_window_multidim.push_back(index.GetConstantWithIndexType(0));
2483       } else {
2484         input_window_bounds.push_back(
2485             raw_window_bounds[raw_window_multidim_idx]);
2486         input_window_multidim.push_back(
2487             raw_window_multidim[raw_window_multidim_idx]);
2488         ++raw_window_multidim_idx;
2489       }
2490     }
2491     DCHECK_EQ(input_window_multidim.size(), desc.operand_shape.rank());
2492 
2493     // Insert a 1 dimension at the end if index_vector_dim requests one.
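    // For example, scatter_indices of shape s32[7] with index_vector_dim == 1
    // (its rank) is treated as s32[7,1], so the loads below always have an
    // explicit index-vector dimension to step through.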
2494     Shape scatter_indices_shape_fixed = desc.scatter_indices_shape;
2495     if (desc.dim_numbers.index_vector_dim().getInt() ==
2496         desc.scatter_indices_shape.rank()) {
2497       scatter_indices_shape_fixed.add_dimensions(1);
2498       scatter_indices_shape_fixed.mutable_layout()->add_minor_to_major(
2499           desc.dim_numbers.index_vector_dim().getInt());
2500     }
2501 
2502     // Now load the indices corresponding to the current window from
2503     // scatter_indices.
2504     std::vector<llvm::Value*> raw_scatter_index_multidim =
2505         input_scatter_multidim;
2506     raw_scatter_index_multidim.insert(
2507         raw_scatter_index_multidim.begin() +
2508             desc.dim_numbers.index_vector_dim().getInt(),
2509         nullptr);
2510     llvm::Value* is_in_bounds = b_.getTrue();
2511     for (int64_t i = 0,
2512                  e = desc.dim_numbers.scatter_dims_to_operand_dims().size();
2513          i != e; ++i) {
2514       // Our index is stored along index_vector_dim; insert that into the lookup
2515       // index into scatter_indices.
2516       raw_scatter_index_multidim[desc.dim_numbers.index_vector_dim().getInt()] =
2517           index.GetConstantWithIndexType(i);
2518       llvm_ir::IrArray::Index raw_scatter_index_index(
2519           raw_scatter_index_multidim, scatter_indices_shape_fixed,
2520           index.GetType());
2521 
2522       int64_t operand_dim =
2523           desc.dim_numbers.scatter_dims_to_operand_dims().getValue<int64>(i);
2524       TF_ASSIGN_OR_RETURN(
2525           llvm::Value* const loaded_scatter_index,
2526           desc.scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape(
2527               scatter_indices_shape_fixed, desc.scatter_indices_shape, &b_)));
2528       // And add the index to our window index. This yields the output index.
2529       llvm::Value* casted_scatter_index =
2530           IntCast(loaded_scatter_index, index.GetType(),
2531                   /*isSigned=*/true);
2532       llvm::Value* dim_offset =
2533           Add(input_window_multidim[operand_dim], casted_scatter_index);
2534       input_window_multidim[operand_dim] = dim_offset;
2535 
2536       // Also do the bounds check now.
2537       int64_t max_index = desc.operand_shape.dimensions(operand_dim) -
2538                           input_window_bounds[operand_dim] + 1;
2539       // is_in_bounds = index >= 0 && index < dim_size-window_size+1
2540       //   --> index u< dim_size-window_size+1
2541       is_in_bounds =
2542           And(is_in_bounds, ICmpULT(casted_scatter_index,
2543                                     index.GetConstantWithIndexType(max_index)));
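      // E.g. with dim_size = 10 and window size 3, max_index = 8: scatter
      // indices 0..7 pass the single unsigned compare, 8 and above fail, and a
      // negative index wraps to a huge unsigned value and also fails, matching
      // the signed check in the comment above.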
2544     }
2545 
2546     llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse(
2547         is_in_bounds, "scatter.in_bounds", &b_, /*emit_else=*/false);
2548     llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_);
2549     // All done. Now read the input value calculated from the window and do an
2550     // atomic store to the calculated location in the output.
2551     llvm_ir::IrArray::Index input_window_index(
2552         input_window_multidim, desc.output.GetShape(), index.GetType());
2553     llvm::Value* output_address =
2554         desc.output.EmitArrayElementAddress(input_window_index, &b_);
2555     llvm::Value* input_address = llvm_ir::EmitAllocaAtFunctionEntry(
2556         llvm_ir::PrimitiveTypeToIrType(desc.updates_shape.element_type(),
2557                                        module_),
2558         "input_address", &b_);
2559     TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
2560                         desc.updates_gen(index));
2561     Store(input_ir_value, input_address);
2562 
2563     if (!desc.unique_indices) {
2564       return EmitAtomicOperationForNestedComputation(
2565           *desc.update_computation, output_address, input_address);
2566     } else {
2567       return EmitCallToNestedComputation(*desc.update_computation,
2568                                          {output_address, input_address},
2569                                          output_address);
2570     }
2571   };
2572 
2573   // Launch a kernel that reads every element in the updates tensor. We could
2574   // also do one kernel per window instead if bounds checks turn out to be a
2575   // bottleneck.
2576   return ParallelLoopEmitter(loop_body_emitter, desc.updates_shape,
2577                              launch_dimensions, &b_)
2578       .EmitLoop(desc.name,
2579                 desc.get_index_type(launch_dimensions.launch_bound()));
2580 }
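
// Illustrative sketch (editorial addition, not part of XLA): the same index
// arithmetic as the loop-body emitter above, written as plain host code over
// int64_t values so the shape of the computation is easier to see. The name
// ScatterOperandIndexSketch and the parameter `scatter_index_values` (standing
// in for what `scatter_indices_gen` would load for the current update element)
// are hypothetical; the dimension-number vectors are assumed to be sorted,
// matching the binary searches in the emitter.
inline std::vector<int64_t> ScatterOperandIndexSketch(
    const std::vector<int64_t>& update_index,
    const std::vector<int64_t>& update_window_dims,
    const std::vector<int64_t>& inserted_window_dims,
    const std::vector<int64_t>& scatter_dims_to_operand_dims,
    const std::vector<int64_t>& scatter_index_values) {
  // Keep only the window components of the update index; the remaining
  // components select the row of scatter_indices (already folded into
  // `scatter_index_values` here).
  std::vector<int64_t> window_index;
  for (int64_t i = 0; i < static_cast<int64_t>(update_index.size()); ++i) {
    if (std::binary_search(update_window_dims.begin(),
                           update_window_dims.end(), i)) {
      window_index.push_back(update_index[i]);
    }
  }
  // Expand to operand rank, inserting 0 for each trivial inserted window dim.
  const int64_t operand_rank =
      static_cast<int64_t>(window_index.size() + inserted_window_dims.size());
  std::vector<int64_t> operand_index;
  int64_t window_pos = 0;
  for (int64_t d = 0; d < operand_rank; ++d) {
    if (std::binary_search(inserted_window_dims.begin(),
                           inserted_window_dims.end(), d)) {
      operand_index.push_back(0);
    } else {
      operand_index.push_back(window_index[window_pos++]);
    }
  }
  // Offset the window by the scatter indices to get the output location.
  for (int64_t i = 0;
       i < static_cast<int64_t>(scatter_dims_to_operand_dims.size()); ++i) {
    operand_index[scatter_dims_to_operand_dims[i]] += scatter_index_values[i];
  }
  return operand_index;
}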
2581 
2582 // This transformation should be migrated off. See b/171334474.
2583 StatusOr<HloComputation*>
2584 IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region,
2585                                                        bool is_fusion) {
2586   std::unique_ptr<HloModule>& module = scratch_nested_computations_[region];
2587   if (module == nullptr) {
2588     std::vector<Shape> operand_shapes, output_shapes;
2589     if (is_fusion) {
2590       mlir::Operation* clone = region->getParentOp()->clone();
2591       region = &mlir::cast<mlir::lmhlo::FusionOp>(clone).region();
2592       TF_RETURN_IF_ERROR(
2593           ProcessFusionForConversion(region, &operand_shapes, &output_shapes));
2594     }
2595 
2596     xla::XlaComputation xla_computation;
2597     mlir::MlirToHloConversionOptions options;
2598     options.propagate_layouts = true;
2599     options.propagate_bitcast_layouts_to_backend_config = true;
2600     TF_RETURN_IF_ERROR(
2601         ConvertRegionToComputation(region, &xla_computation, options));
2602 
2603     if (is_fusion) {
2604       region->getParentOp()->erase();
2605     }
2606 
2607     TF_ASSIGN_OR_RETURN(auto program_shape, xla_computation.GetProgramShape());
2608     TF_ASSIGN_OR_RETURN(
2609         module, HloModule::CreateFromProto(xla_computation.proto(),
2610                                            HloModuleConfig(program_shape)));
2611 
2612     if (is_fusion) {
2613       HloComputation* fused_computation = module->entry_computation();
2614 
2615       CHECK_EQ(operand_shapes.size(), fused_computation->num_parameters());
2616       for (int i = 0; i < fused_computation->num_parameters(); i++) {
2617         *fused_computation->parameter_instruction(i)
2618              ->mutable_shape()
2619              ->mutable_layout() = operand_shapes[i].layout();
2620       }
2621       HloInstruction* root = fused_computation->root_instruction();
2622       // Manually fold Tuple(GTE(a, 0), GTE(a, 1), GTE(a, 2), ...) to a.
2623       // FusedIrEmitter doesn't take GTE ops because we aim to eliminate tuples
2624       // as much as possible.
2625       if (root->opcode() == HloOpcode::kTuple) {
2626         [&] {
2627           HloInstruction* real_root = nullptr;
2628           int expected_tuple_index = 0;
2629           for (HloInstruction* operand : root->operands()) {
2630             if (operand->opcode() != HloOpcode::kGetTupleElement) {
2631               return;
2632             }
2633             if (real_root == nullptr) {
2634               real_root = operand->mutable_operand(0);
2635             } else if (real_root != operand->operand(0)) {
2636               return;
2637             }
2638             if (expected_tuple_index != operand->tuple_index()) {
2639               return;
2640             }
2641             expected_tuple_index++;
2642           }
2643           fused_computation->set_root_instruction(real_root);
2644           std::vector<HloInstruction*> to_be_removed;
2645           to_be_removed.push_back(root);
2646           for (HloInstruction* operand : root->operands()) {
2647             to_be_removed.push_back(operand);
2648           }
2649           for (auto instr : to_be_removed) {
2650             TF_CHECK_OK(fused_computation->RemoveInstruction(instr));
2651           }
2652 
2653           root = real_root;
2654         }();
2655       }
2656 
2657       if (output_shapes.size() > 1) {
2658         CHECK(root->shape().IsTuple());
2659         CHECK_EQ(root->shape().tuple_shapes_size(), output_shapes.size());
2660 
2661         for (int i = 0; i < output_shapes.size(); i++) {
2662           *root->mutable_shape()->mutable_tuple_shapes(i) = output_shapes.at(i);
2663         }
2664       } else {
2665         CHECK_EQ(1, output_shapes.size());
2666         *root->mutable_shape() = output_shapes[0];
2667       }
2668     }
2669     // Post-process the generated computation:
2670     // * Sanitize constant names, so that they can be used as LLVM global
2671     // symbols.
2672     // * Propagate layouts for tuple types.
2673     for (HloComputation* computation : module->computations()) {
2674       for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
2675         if (instr->opcode() == HloOpcode::kConstant) {
2676           // Notice that IR emitters use the name of constants as LLVM symbol
2677           // names, therefore it's important to not let these constants in the
2678           // new module collide with constants in the original module by names.
2679           // Unique them by prepending the module name.
2680           //
2681           // TODO(timshen): A better solution would be to plumb the exact
2682           // constant names through original HLO -> LHLO -> MHLO -> HLO. This is
2683           // hard because XLA builder doesn't support setting names. Revisit
2684           // this once we get rid of this function, or don't rely on the op name
2685           // (which shouldn't be the identity) to generate LLVM symbols.
2686           instr->SetAndSanitizeName(llvm_ir::SanitizeConstantName(
2687               module->name() + "_" + instr->name()));
2688         }
2689         if (instr->shape().IsTuple() &&
2690             computation == module->entry_computation() &&
2691             instr != computation->root_instruction()) {
2692           return InternalError("Non-root tuple types are not handled.");
2693         }
2694       }
2695     }
2696   }
2697   return module->entry_computation();
2698 }
2699 
2700 Status IrEmitterUnnested::EmitSort(mlir::Operation* op) {
2701   auto sort_op = mlir::cast<mlir::lmhlo::SortOp>(op);
2702   MlirEmitterContext context;
2703   context.SetOperation(sort_op);
2704 
2705   ThunkSequence thunks;
2706 
2707   const Shape& keys_shape = context.operand_shapes[0];
2708   int64_t dimension_to_sort = sort_op.dimension();
2709   for (int64_t i = 0; i < context.operand_shapes.size(); ++i) {
2710     // We assume that the layout of all involved operands and outputs is the
2711     // same.
2712     TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(keys_shape,
2713                                                   context.operand_shapes[i]));
2714     TF_RET_CHECK(
2715         LayoutUtil::LayoutsInShapesEqual(keys_shape, context.output_shapes[i]));
2716 
2717     // If possible, we share buffers. If that is not possible, we need to copy
2718     // the values, because the emitter does the sorting in-place.
2719     TF_ASSIGN_OR_RETURN(auto destination_buffer,
2720                         GetAllocationSlice(sort_op.output()[i]));
2721     TF_ASSIGN_OR_RETURN(auto source_address,
2722                         GetAllocationSlice(sort_op.operands()[i]));
2723     if (destination_buffer != source_address) {
2724       // TODO(b/26783907): Figure out why we never seem to share buffers for
2725       // key/value sort.
2726       VLOG(2) << context.name << " requires initial D2D copy for operand " << i;
2727       thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
2728           Thunk::ThunkInfo(),
2729           /*source_address=*/source_address,
2730           /*destination_buffer=*/destination_buffer,
2731           /*mem_size=*/ShapeUtil::ByteSizeOf(context.operand_shapes[i])));
2732     }
2733   }
2734 
2735   uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
2736   int64_t num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound);
2737   VLOG(2) << context.name << " requires " << num_stages << " stages.";
2738   CHECK_GE(1ULL << num_stages, dimension_to_sort_bound);
2739   CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound);
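  // E.g. dimension_to_sort_bound = 300 gives num_stages = 9, since
  // 2^9 = 512 >= 300 > 256 = 2^8.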
2740 
2741   // Naive C++ code for the outer loops:
2742   //
2743   // for (int64 stage = 0; stage < Log2Ceiling(dimension_to_sort_bound);
2744   //     ++stage) {
2745   //   int64 first_xor_mask = (1LL << (stage + 1)) - 1;
2746   //   SortInPlace(first_xor_mask);
2747   //   for (int64 mask = stage - 1; mask >= 0; --mask) {
2748   //     int64 later_xor_mask = 1LL << mask;
2749   //     SortInPlace(later_xor_mask);
2750   //   }
2751   // }
2752   //
2753   // This follows the alternative representation of the algorithm described on
2754   // Wikipedia: https://en.wikipedia.org/wiki/Bitonic_sorter
2755   //
2756   // Each mask specifies how to derive from one position in the array the
2757   // position with which it should be compared (we calculate the xor of the
2758   // position with the mask).
2759   // As an optimization, we can move the 'mask' loop to inside the
2760   // sorting/comparison loop if the comparisons happen within a small block of
2761   // the array. To make this work, we collect all consecutive masks that are
2762   // smaller than our chosen power of 2 tile size, and pass them to SortInPlace.
2763   // Each thread then processes one tile of data.
2764 
2765   const uint64 kTileSize = std::min(2048ULL, 1ULL << num_stages);
2766 
2767   // If we cannot combine several xor masks together, we don't use tiling, so we
2768   // calculate the standard launch dimensions for the shape. However we only
2769   // need to iterate through ~half of the dimension to sort (rounded up to the
2770   // next highest power of 2), because each iteration compares one pair of
2771   // elements.
2772   Shape standard_iteration_shape = keys_shape;
2773   uint64 standard_num_iterations_in_sort_dim = 1ULL << (num_stages - 1);
2774   standard_iteration_shape.set_dimensions(dimension_to_sort,
2775                                           standard_num_iterations_in_sort_dim);
2776   TF_ASSIGN_OR_RETURN(
2777       LaunchDimensions standard_launch_dimensions,
2778       CalculateLaunchDimensions(standard_iteration_shape,
2779                                 ir_emitter_context_->gpu_device_info()));
2780 
2781   // Calculate the launch dimensions for the case where we use tiling. We split
2782   // the dimension that should be sorted into tiles of size 'kTileSize'. This
2783   // means we first need to round 'dimension_to_sort_bound' up to be a multiple
2784   // of the tile size.
2785   int64_t rounded_bound = RoundUpToNearest(dimension_to_sort_bound, kTileSize);
2786   Shape iteration_shape = keys_shape;
2787 
2788   // We iterate through the element pairs that should be compared.
2789   uint64 num_iterations_in_sort_dim = rounded_bound / 2;
2790   iteration_shape.set_dimensions(dimension_to_sort, num_iterations_in_sort_dim);
2791   uint64 num_iterations = ShapeUtil::ElementsIn(iteration_shape);
2792 
2793   // For correctness reasons we need exactly 'kTileSize' / 2 threads per
2794   // block. Each thread is responsible for copying exactly two adjacent elements
2795   // into shared memory, and then does a comparison of two possibly different
2796   // elements taken from shared memory.
2797   const uint64 kThreadsPerBlock = kTileSize / 2;
2798 
2799   // Check whether we should use any tiling. We might not be able to use it if
2800   // we don't have enough threads or enough shared memory. Also it does not
2801   // give a speedup if the tile size is < 128.
2802   int64_t total_shared_memory_needed = 0;
2803   for (int64_t i = 0; i < context.operand_shapes.size(); ++i) {
2804     total_shared_memory_needed +=
2805         kTileSize * ShapeUtil::ByteSizeOfPrimitiveType(
2806                         context.operand_shapes[i].element_type());
2807   }
2808   bool no_tiling =
2809       kTileSize < 128 ||
2810       kThreadsPerBlock >
2811           ir_emitter_context_->gpu_device_info().threads_per_block_limit ||
2812       total_shared_memory_needed >
2813           ir_emitter_context_->gpu_device_info().shared_memory_per_block;
2814   VLOG(2) << absl::StreamFormat(
2815       "%s %s use tiling. No tiling if any of the following is true: "
2816       "kTileSize=%d < 128, "
2817       "kThreadsPerBlock=%d > threads_per_block_limit=%d, "
2818       "total_shared_memory_needed=%d > shared_memory_per_block=%d",
2819       context.name, (no_tiling ? "won't" : "will"), kTileSize, kThreadsPerBlock,
2820       ir_emitter_context_->gpu_device_info().threads_per_block_limit,
2821       total_shared_memory_needed,
2822       ir_emitter_context_->gpu_device_info().shared_memory_per_block);
2823 
2824   uint64 num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock);
2825   LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock);
2826   VLOG(2) << absl::StreamFormat("%s launch dims: %d blocks, %d threads/block",
2827                                 context.name, num_blocks, kThreadsPerBlock);
2828 
2829   std::vector<llvm_ir::IrArray> ir_arrays;
2830   auto emit_kernel = [&](absl::Span<const int64> xor_masks) {
2831     VLOG(2) << absl::StreamFormat(
2832         "%s uses kernel for xor masks [%s]", context.name,
2833         absl::StrJoin(xor_masks, ", ", [](std::string* out, int64_t xor_mask) {
2834           absl::StrAppendFormat(out, "0x%x", xor_mask);
2835         }));
2836     thunks.emplace_back();
2837     LaunchDimensions launch_dimensions = xor_masks.size() > 1
2838                                              ? tiled_launch_dimensions
2839                                              : standard_launch_dimensions;
2840     TF_ASSIGN_OR_RETURN(
2841         thunks.back(),
2842         BuildKernelThunk(sort_op, sort_op.output(), Thunk::ThunkInfo(),
2843                          &ir_arrays, launch_dimensions));
2844     std::vector<IrArray> values_arrays;
2845     values_arrays.reserve(context.operand_shapes.size());
2846     for (int64_t i = 0; i < context.operand_shapes.size(); ++i) {
2847       values_arrays.push_back(ir_arrays[i]);
2848     }
2849     TF_ASSIGN_OR_RETURN(const HloComputation* comparator,
2850                         GetOrCreateSubComputationFromRegion(
2851                             &sort_op.comparator(), /*is_fusion=*/false));
2852     return llvm_ir::EmitSortInPlace(
2853         dimension_to_sort, values_arrays, IrName(context.name), xor_masks, &b_,
2854         launch_dimensions,
2855         xor_masks.size() > 1 ? num_iterations_in_sort_dim
2856                              : standard_num_iterations_in_sort_dim,
2857         kTileSize,
2858         [&](absl::Span<llvm::Value* const> operands, llvm::Value* output) {
2859           return EmitCallToNestedComputation(*comparator, operands, output);
2860         });
2861   };
2862   std::vector<int64> xor_masks;
2863   for (int64_t stage = 0; stage < num_stages; ++stage) {
2864     for (int64_t mask = stage; mask >= 0; --mask) {
2865       int64_t xor_mask;
2866       if (mask == stage) {
2867         xor_mask = (1LL << (stage + 1)) - 1;
2868       } else {
2869         xor_mask = 1LL << mask;
2870       }
2871       if (xor_mask >= kTileSize || no_tiling) {
2872         if (!xor_masks.empty()) {
2873           TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
2874           xor_masks.clear();
2875         }
2876         TF_RETURN_IF_ERROR(emit_kernel({xor_mask}));
2877       } else {
2878         xor_masks.push_back(xor_mask);
2879       }
2880     }
2881   }
2882   if (!xor_masks.empty()) {
2883     TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
2884   }
2885   VLOG(2) << absl::StreamFormat(
2886       "%s requires %d thunks (including any D2D copies)", context.name,
2887       thunks.size());
2888 
2889   AddThunkToThunkSequence(
2890       absl::make_unique<SequentialThunk>(GetThunkInfo(op), std::move(thunks)));
2891   return Status::OK();
2892 }
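
// Illustrative sketch (editorial addition, not part of XLA): the grouping of
// bitonic-sort xor masks into kernels that EmitSort performs above.
// Consecutive masks smaller than the tile size are batched into one tiled
// kernel; a larger mask flushes the batch and gets its own untiled kernel.
// The name SortXorMaskGroupsSketch is hypothetical, and the example below uses
// a hypothetical tile size of 4 with num_stages = 4 (dimension_to_sort_bound =
// 16), which returns {1,3,1}, {7}, {2,1}, {15}, {4}, {2,1}.
inline std::vector<std::vector<int64_t>> SortXorMaskGroupsSketch(
    int64_t num_stages, int64_t tile_size, bool no_tiling) {
  std::vector<std::vector<int64_t>> kernel_groups;
  std::vector<int64_t> pending;
  for (int64_t stage = 0; stage < num_stages; ++stage) {
    for (int64_t mask = stage; mask >= 0; --mask) {
      // Same mask derivation as the loop in EmitSort.
      int64_t xor_mask =
          mask == stage ? (1LL << (stage + 1)) - 1 : (1LL << mask);
      if (xor_mask >= tile_size || no_tiling) {
        if (!pending.empty()) {
          kernel_groups.push_back(pending);
          pending.clear();
        }
        kernel_groups.push_back({xor_mask});
      } else {
        pending.push_back(xor_mask);
      }
    }
  }
  if (!pending.empty()) kernel_groups.push_back(pending);
  return kernel_groups;
}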
2893 
2894 template <typename ThunkType, typename OpT>
2895 Status IrEmitterUnnested::EmitReplicaOrPartitionId(mlir::Operation* op) {
2896   auto casted = mlir::cast<OpT>(op);
2897   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
2898                       GetAllocationSlice(casted.getOperand()));
2899   AddThunkToThunkSequence(
2900       absl::make_unique<ThunkType>(GetThunkInfo(op), result_slice));
2901   return Status::OK();
2902 }
2903 
2904 Status IrEmitterUnnested::EmitCollectivePermute(mlir::Operation* op) {
2905   auto collective_permute_op = mlir::cast<mlir::lmhlo::CollectivePermuteOp>(op);
2906 
2907   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice source_slice,
2908                       GetAllocationSlice(collective_permute_op.operand()));
2909   TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
2910                       GetAllocationSlice(collective_permute_op.output()));
2911 
2912   const Shape shape = GetShape(collective_permute_op.operand());
2913   const int64_t replica_count = hlo_module_config_.replica_count();
2914   const int64_t partition_count = hlo_module_config_.num_partitions();
2915 
2916   if (NcclCollectivePermuteThunk::IsDegenerate(
2917           collective_permute_op, replica_count, partition_count)) {
2918     // For a degenerate collective permute, just generate a copy thunk.
2919     AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
2920         GetThunkInfo(op),
2921         /*source_address=*/source_slice,
2922         /*destination_buffer=*/result_slice,
2923         /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
2924   } else {
2925     const NcclCollectivePermuteThunk::Buffer buffer = {
2926         /*element_count=*/ShapeUtil::ElementsIn(shape),
2927         /*source_buffer=*/source_slice,
2928         /*destination_buffer=*/result_slice};
2929     AddThunkToThunkSequence(absl::make_unique<NcclCollectivePermuteThunk>(
2930         GetThunkInfo(op), collective_permute_op, replica_count, partition_count,
2931         buffer));
2932   }
2933   return Status::OK();
2934 }
2935 
2936 template <typename NcclThunkType>
2937 Status MaybeAddAllReduceStartThunkToMap(
2938     absl::flat_hash_map<mlir::Operation*, NcclAllReduceStartThunk*>&,
2939     mlir::Operation* op, NcclThunkType* thunk) {
2940   return Status::OK();
2941 }
2942 
2943 template <>
2944 Status MaybeAddAllReduceStartThunkToMap(
2945     absl::flat_hash_map<mlir::Operation*, NcclAllReduceStartThunk*>&
2946         all_reduce_start_thunks,
2947     mlir::Operation* op, NcclAllReduceStartThunk* thunk) {
2948   TF_RET_CHECK(all_reduce_start_thunks.emplace(op, thunk).second)
2949       << "all-reduce-start with this unique ID already seen";
2950   return Status::OK();
2951 }
2952 
2953 template <typename NcclThunkType, typename OpTy>
2954 Status IrEmitterUnnested::EmitNcclThunk(mlir::Operation* untyped_op) {
2955   OpTy op = mlir::cast<OpTy>(untyped_op);
2956   int64_t replica_count = hlo_module_config_.replica_count();
2957   int64_t partition_count = hlo_module_config_.num_partitions();
2958   VLOG(2) << NcclThunkType::GetName() << "; replica count: " << replica_count
2959           << "; partition count: " << partition_count
2960           << "; operand count: " << op.operands().size()
2961           << "; NCCL is enabled: " << NcclThunkType::NcclIsEnabled();
2962 
2963   // A given collective op can be degenerate if all groups formed by it are
2964   // singletons. In such a case, we don't need to do any communication
2965   // and we can just copy the input to the output.
2966   bool is_degenerate =
2967       NcclThunkType::IsDegenerate(op, replica_count, partition_count);
2968   bool should_use_nccl_thunk =
2969       !is_degenerate && NcclThunkType::CanImplement(op);
2970 
2971   // Stash relevant information in NcclCollectiveThunk::Buffer even if we may
2972   // not generate an NcclCollectiveThunk.
2973   std::vector<NcclCollectiveThunk::Buffer> buffers;
2974   buffers.reserve(op.operands().size());
2975   for (auto it : llvm::zip(op.operands(), op.results())) {
2976     mlir::Value operand = std::get<0>(it);
2977     mlir::Value result = std::get<1>(it);
2978     const Shape shape = GetShape(operand);
2979     TF_ASSIGN_OR_RETURN(auto source_slice, GetAllocationSlice(operand));
2980     TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSlice(result));
2981     buffers.push_back(NcclCollectiveThunk::Buffer{
2982         /*element_count=*/ShapeUtil::ElementsIn(shape),
2983         /*source_buffer=*/source_slice,
2984         /*destination_buffer=*/dest_slice});
2985   }
2986 
2987   if (should_use_nccl_thunk) {
2988     auto nccl_thunk =
2989         absl::make_unique<NcclThunkType>(GetThunkInfo(op), op,
2990                                          /*buffers=*/std::move(buffers));
2991     // Record thunks for all-reduce-start ops as the done ops need them.
2992     TF_RETURN_IF_ERROR(MaybeAddAllReduceStartThunkToMap(
2993         all_reduce_start_thunks_, op, nccl_thunk.get()));
2994     AddThunkToThunkSequence(std::move(nccl_thunk));
2995     return Status::OK();
2996   }
2997 
2998   // Record a nullptr to signal that no all-reduce-start thunk was created.
2999   TF_RETURN_IF_ERROR(MaybeAddAllReduceStartThunkToMap(
3000       all_reduce_start_thunks_, op, static_cast<NcclThunkType*>(nullptr)));
3001 
3002   if (!is_degenerate) {
3003     CollectiveOpGroupMode group_mode = NcclThunkType::GetGroupMode(op);
3004 
3005     string message = absl::StrFormat(
3006         "Requested %s not implemented on GPU; replica_count: %d; "
3007         "partition_count: %d, group_mode: %s, operand_count: %d; NCCL support: "
3008         "%d",
3009         NcclThunkType::GetName(), replica_count, partition_count,
3010         CollectiveOpGroupModeToString(group_mode), op.operands().size(),
3011         NcclThunkType::NcclIsEnabled());
3012     if (!op.operands().empty()) {
3013       const Shape shape = GetShape(op.operands().front());
3014       absl::StrAppendFormat(&message, "; first operand array element-type: %s",
3015                             PrimitiveType_Name(shape.element_type()));
3016     }
3017     return Unimplemented("%s", message);
3018   }
3019 
3020   // A degenerate collective op is simply the identity function. Buffer
3021   // assignment expects a copy, so that's what we do.
3022   ThunkSequence thunks;
3023   for (int64_t i = 0; i < buffers.size(); i++) {
3024     const Shape shape = GetShape(op.operands()[i]);
3025     thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
3026         buffers.size() == 1 ? GetThunkInfo(op) : Thunk::ThunkInfo(),
3027         /*source_address=*/buffers[i].source_buffer,
3028         /*destination_buffer=*/buffers[i].destination_buffer,
3029         /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
3030   }
3031   if (thunks.size() == 1) {
3032     AddThunkToThunkSequence(std::move(thunks[0]));
3033   } else {
3034     AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
3035         GetThunkInfo(op), std::move(thunks)));
3036   }
3037   return Status::OK();
3038 }
3039 
3040 Status IrEmitterUnnested::EmitAllReduceDone(mlir::Operation* op) {
3041   auto done_op = mlir::cast<mlir::lmhlo_gpu::AllReduceDoneOp>(op);
3042   auto start_op =
3043       done_op.token().getDefiningOp<mlir::lmhlo_gpu::AllReduceStartOp>();
3044   auto it = all_reduce_start_thunks_.find(start_op);
3045   TF_RET_CHECK(it != all_reduce_start_thunks_.end())
3046       << "couldn't find thunk for all-reduce-start op";
3047 
3048   // Can be null if no all-reduce-start thunk was created (e.g. if the start op
3049   // is degenerate), in which case there's nothing to do here.
3050   if (it->second != nullptr) {
3051     AddThunkToThunkSequence(absl::make_unique<NcclAllReduceDoneThunk>(
3052         GetThunkInfo(op), *it->second));
3053   }
3054   all_reduce_start_thunks_.erase(it);
3055   return Status::OK();
3056 }
3057 
3058 Status IrEmitterUnnested::EmitInfeed(mlir::Operation* op) {
3059   auto infeed_op = mlir::cast<mlir::lmhlo::InfeedOp>(op);
3060 
3061   std::vector<ShapedSlice> dest_slices;
3062   dest_slices.reserve(infeed_op.outputs().size());
3063 
3064   for (mlir::Value output : infeed_op.outputs()) {
3065     TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(output));
3066     const Shape& shape = GetShape(output);
3067     dest_slices.push_back(ShapedSlice{slice, shape});
3068   }
3069 
3070   AddThunkToThunkSequence(
3071       absl::make_unique<InfeedThunk>(GetThunkInfo(op), std::move(dest_slices)));
3072   return Status::OK();
3073 }
3074 
3075 Status IrEmitterUnnested::EmitOutfeed(mlir::Operation* op) {
3076   auto outfeed_op = mlir::cast<mlir::lmhlo::OutfeedOp>(op);
3077 
3078   std::vector<ShapedSlice> source_slices;
3079   source_slices.reserve(outfeed_op.operands().size());
3080 
3081   for (mlir::Value operand : outfeed_op.operands()) {
3082     TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(operand));
3083     const Shape& shape = GetShape(operand);
3084     source_slices.push_back(ShapedSlice{slice, shape});
3085   }
3086 
3087   AddThunkToThunkSequence(absl::make_unique<OutfeedThunk>(
3088       GetThunkInfo(op), std::move(source_slices)));
3089   return Status::OK();
3090 }
3091 
3092 std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunkImpl(
3093     absl::string_view name, Thunk::ThunkInfo thunk_info,
3094     absl::Span<const BufferSlice> slices,
3095     std::vector<llvm_ir::IrArray>* ir_arrays,
3096     const LaunchDimensions& launch_dimensions) {
3097   // Figure out which buffer allocations need to be passed as arguments to our
3098   // kernel.  This is simply all of the allocations referenced in slices,
3099   // plus the XLA temp buffer (if we have it).  We always include the temp
3100   // buffer because even if the kernel itself doesn't use it, a nested
3101   // subcomputation within the kernel (e.g. a kMap's computation) might.
3102   std::unordered_set<const BufferAllocation*> buffers_needed;
3103   for (const auto& slice : slices) {
3104     buffers_needed.insert(slice.buffer_slice.allocation());
3105   }
3106   absl::optional<const BufferAllocation*> temp_buffer;
3107   for (const BufferAllocation& alloc : ir_emitter_context_->allocations()) {
3108     if (alloc.IsPreallocatedTempBuffer()) {
3109       if (!temp_buffer.has_value()) {
3110         // Retrieve the first seen temp buffer.
3111         temp_buffer = &alloc;
3112       }
3113     }
3114   }
3115   if (temp_buffer.has_value()) {
3116     buffers_needed.insert(*temp_buffer);
3117   }
3118 
3119   // We'll pass a pointer to each of the elements of `buffers` to our kernel, in
3120   // this order.
3121   std::vector<const BufferAllocation*> non_constant_buffers;
3122   absl::c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers),
3123                   [](const BufferAllocation* allocation) {
3124                     return !allocation->is_constant();
3125                   });
3126 
3127   absl::c_sort(non_constant_buffers,
3128                [](const BufferAllocation* a, const BufferAllocation* b) {
3129                  return a->index() < b->index();
3130                });
3131 
3132   llvm::Function* kernel = BuildKernelPrototype(name, non_constant_buffers);
3133 
3134   // Build a map from a BufferAllocation to the corresponding argument in our
3135   // kernel.
3136   std::unordered_map<const BufferAllocation*, llvm::Value*> kernel_args;
3137   {
3138     auto arg_it = kernel->arg_begin();
3139     auto buffers_it = non_constant_buffers.begin();
3140     for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) {
3141       kernel_args[*buffers_it] = arg_it;
3142 
3143       // Annotate all allocations with LLVM's `noalias`.
3144       // There are three kinds of allocations:
3145       // * Read-only allocations, aka input parameters that are not aliased with
3146       // outputs.
3147       // * Read-write allocations, including all output buffers, some of which
3148       // may alias with input HLO parameters, but aliased HLO buffers are always
3149       // assigned with the same allocation.
3150       // * The temp buffer.
3151       //
3152       // Read-only allocations may overlap with each other, but since they are
3153       // not mutated, they can always be annotated with `noalias` per LLVM
3154       // semantics.
3155       //
3156       // Read-write allocations and the temp buffer don't overlap with any
3157       // allocations, therefore they can also be annotated with `noalias`.
3158       kernel->addParamAttr(
3159           arg_it->getArgNo(),
3160           llvm::Attribute::get(arg_it->getContext(), llvm::Attribute::NoAlias));
3161     }
3162   }
3163 
3164   absl::flat_hash_set<BufferAllocation::Slice> buffers_written;
3165   for (const auto& slice : slices) {
3166     if (slice.written) {
3167       buffers_written.insert(slice.buffer_slice);
3168     }
3169   }
3170 
3171   ir_arrays->clear();
3172 
3173   // For each buffer our kernel might want to touch, bind it to a value derived
3174   // from our kernel args.
3175   for (const auto& slice : slices) {
3176     const BufferAllocation::Slice& buffer_slice = slice.buffer_slice;
3177 
3178     llvm::Value* loc;
3179     if (!slice.constant_name.empty()) {
3180       loc = ir_emitter_context_->llvm_module()->getGlobalVariable(
3181           slice.constant_name);
3182       CHECK_NE(loc, nullptr);
3183     } else {
3184       CHECK(!buffer_slice.allocation()->is_constant());
3185       loc = InBoundsGEP(kernel_args.at(buffer_slice.allocation()),
3186                         {b_.getInt64(buffer_slice.offset())});
3187     }
3188 
3189     llvm_ir::IrArray ir_array(CastToTypedValue(slice.shape, loc, &b_),
3190                               slice.shape);
3191     if (!buffers_written.contains(slice.buffer_slice)) {
3192       ir_array.MarkInvariantOverWholeProgram(&loc->getContext());
3193     }
3194 
3195     ir_arrays->push_back(ir_array);
3196   }
3197 
3198   AnnotateThunkLaunchDimensions(launch_dimensions,
3199                                 std::string(kernel->getName()),
3200                                 ir_emitter_context_->llvm_module());
3201   return absl::make_unique<KernelThunk>(thunk_info, non_constant_buffers,
3202                                         std::string(kernel->getName()),
3203                                         launch_dimensions);
3204 }
3205 
3206 StatusOr<std::unique_ptr<KernelThunk>> IrEmitterUnnested::BuildKernelThunk(
3207     mlir::Operation* op, mlir::ValueRange operands, Thunk::ThunkInfo thunk_info,
3208     std::vector<llvm_ir::IrArray>* ir_arrays,
3209     const LaunchDimensions& launch_dimensions) {
3210   TF_RET_CHECK(!mlir::isa<mlir::lmhlo::FusionOp>(op));
3211 
3212   std::vector<BufferSlice> slices;
3213   for (mlir::Value operand : operands) {
3214     slices.emplace_back();
3215     auto& slice = slices.back();
3216     TF_ASSIGN_OR_RETURN(slice.buffer_slice,
3217                         GetAllocationSlice(operand, &slice.constant_name));
3218     slice.written = WritesMlirBuffer(op, operand);
3219     slice.shape = GetShape(operand);
3220   }
3221   std::string name = mlir::GetNameFromLoc(op->getLoc());
3222   return BuildKernelThunkImpl(name, thunk_info, slices, ir_arrays,
3223                               launch_dimensions);
3224 }
3225 
3226 StatusOr<std::unique_ptr<KernelThunk>> IrEmitterUnnested::BuildKernelThunk(
3227     mlir::Operation* op, Thunk::ThunkInfo thunk_info,
3228     std::vector<llvm_ir::IrArray>* ir_arrays,
3229     const LaunchDimensions& launch_dimensions) {
3230   if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
3231     auto operands = GetHloOperands(op);
3232     auto outputs = GetHloOutputs(op);
3233 
3234     std::vector<BufferSlice> slices;
3235     for (auto operand : operands) {
3236       slices.emplace_back();
3237       auto& slice = slices.back();
3238       TF_ASSIGN_OR_RETURN(slice.buffer_slice,
3239                           GetAllocationSlice(operand, &slice.constant_name));
3240       slice.written = false;
3241       slice.shape = GetShape(operand);
3242     }
3243     for (auto output : outputs) {
3244       slices.emplace_back();
3245       auto& slice = slices.back();
3246       TF_ASSIGN_OR_RETURN(slice.buffer_slice,
3247                           GetAllocationSlice(output, &slice.constant_name));
3248       slice.written = true;
3249       slice.shape = GetShape(output);
3250     }
3251     std::string name = mlir::GetNameFromLoc(op->getLoc());
3252     return BuildKernelThunkImpl(name, thunk_info, slices, ir_arrays,
3253                                 launch_dimensions);
3254   }
3255   return BuildKernelThunk(op, op->getOperands(), thunk_info, ir_arrays,
3256                           launch_dimensions);
3257 }
3258 
3259 std::unique_ptr<Thunk> IrEmitterUnnested::BuildConstantInitializerThunk(
3260     absl::Span<const uint8> init_value, const BufferAllocation::Slice& dest,
3261     const Shape& output_shape) {
3262   int64_t num_bytes = init_value.size();
3263   if (absl::c_all_of(init_value, [](uint8 byte) { return byte == 0; })) {
3264     return absl::make_unique<MemzeroThunk>(Thunk::ThunkInfo(), dest);
3265   }
3266 
3267   // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
3268   // repeating the literal 4 or 2 times, so long as the destination buffer is
3269   // an even multiple of 32 bits long.
3270   if ((num_bytes == 1 || num_bytes == 2) &&
3271       ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) {
3272     uint16 pattern16;
3273     if (num_bytes == 1) {
3274       uint8 b = init_value.front();
3275       pattern16 = uint16{b} | (uint16{b} << 8);
3276     } else {
3277       memcpy(&pattern16, init_value.data(), sizeof(pattern16));
3278     }
3279     uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
3280     return absl::make_unique<Memset32BitValueThunk>(Thunk::ThunkInfo(),
3281                                                     pattern32, dest);
3282   }
3283 
3284   // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
3285   // memset so long as all 32-bit words of the scalar are equal to each other.
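  // (The overlapping memcmp below checks exactly this: comparing the buffer
  // with itself shifted by four bytes verifies 4-byte periodicity.)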
3286   if (num_bytes >= 4 && num_bytes % 4 == 0 &&
3287       memcmp(init_value.data(), init_value.data() + 4, init_value.size() - 4) ==
3288           0) {
3289     uint32 word;
3290     memcpy(&word, init_value.data(), sizeof(word));
3291     return absl::make_unique<Memset32BitValueThunk>(Thunk::ThunkInfo(), word,
3292                                                     dest);
3293   }
3294 
3295   return nullptr;
3296 }
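
// Illustrative sketch (editorial addition, not part of XLA): the pattern
// arithmetic performed above for short literals. A 1-byte literal 0xAB yields
// pattern16 = 0xABAB and pattern32 = 0xABABABAB; a 2-byte literal stored as
// little-endian bytes {0xEF, 0xBE} (i.e. 0xBEEF) yields 0xBEEFBEEF. The name
// Memset32PatternSketch is hypothetical, and plain `unsigned` replaces the TF
// integer aliases to keep the sketch self-contained.
inline unsigned Memset32PatternSketch(const std::vector<unsigned char>& bytes) {
  unsigned pattern16;
  if (bytes.size() == 1) {
    pattern16 = bytes[0] | (static_cast<unsigned>(bytes[0]) << 8);
  } else {
    // Assumes exactly two bytes, interpreted little-endian as the memcpy above
    // does on little-endian hosts.
    pattern16 = bytes[0] | (static_cast<unsigned>(bytes[1]) << 8);
  }
  return pattern16 | (pattern16 << 16);
}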
3297 
3298 StatusOr<std::unique_ptr<Thunk>>
3299 IrEmitterUnnested::TryBuildConstantInitializerThunk(mlir::Value init_value,
3300                                                     mlir::Value dest) {
3301   mlir::DenseElementsAttr const_init;
3302   if (auto get_global_memref =
3303           mlir::dyn_cast_or_null<mlir::memref::GetGlobalOp>(
3304               init_value.getDefiningOp())) {
3305     auto global_memref =
3306         mlir::SymbolTable::lookupNearestSymbolFrom<mlir::memref::GlobalOp>(
3307             get_global_memref, get_global_memref.name());
3308     if (global_memref.constant() && global_memref.initial_value()) {
3309       // If the initial value happens to be a constant, generate a specialized
3310       // thunk.
3311       const_init = global_memref.initial_value()
3312                        .getValue()
3313                        .cast<mlir::DenseElementsAttr>();
3314     }
3315   } else if (auto constant = mlir::dyn_cast_or_null<mlir::mhlo::ConstOp>(
3316                  init_value.getDefiningOp())) {
3317     const_init = constant.value().dyn_cast<mlir::DenseElementsAttr>();
3318   }
3319 
3320   if (const_init) {
3321     std::vector<uint8> literal_bytes;
3322     TF_RETURN_IF_ERROR(
3323         CopyDenseElementsDataToXlaFormat(const_init, &literal_bytes));
3324 
3325     TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSlice(dest));
3326 
3327     const Shape dest_shape = GetShape(dest);
3328     auto thunk =
3329         BuildConstantInitializerThunk(literal_bytes, dest_slice, dest_shape);
3330     if (thunk) {
3331       return {std::move(thunk)};
3332     }
3333   }
3334   return std::unique_ptr<Thunk>();
3335 }
3336 
3337 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
3338     mlir::Operation* op, mlir::Value init_value, mlir::Value dest) {
3339   // The initial value must be a scalar memref.
3340   auto init_type = init_value.getType().dyn_cast<mlir::MemRefType>();
3341   TF_RET_CHECK(init_type.getRank() == 0);
3342 
3343   TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> constant_init_thunk,
3344                       TryBuildConstantInitializerThunk(init_value, dest));
3345   if (constant_init_thunk) {
3346     return {std::move(constant_init_thunk)};
3347   }
3348 
3349   // Otherwise fall back to our slow initializer code. The thunk in this case
3350   // will just need the IR arrays for the initial value and the destination.
3351   const Shape dest_shape = GetShape(dest);
3352   TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
3353                       CalculateLaunchDimensions(
3354                           dest_shape, ir_emitter_context_->gpu_device_info()));
3355   std::vector<llvm_ir::IrArray> ir_arrays;
3356   TF_ASSIGN_OR_RETURN(
3357       std::unique_ptr<KernelThunk> kernel_thunk,
3358       BuildKernelThunk(op, {init_value, dest}, Thunk::ThunkInfo(), &ir_arrays,
3359                        launch_dimensions));
3360 
3361   const llvm_ir::IrArray init_array = ir_arrays[0];
3362   const llvm_ir::IrArray dest_array = ir_arrays[1];
3363 
3364   std::string name = mlir::GetNameFromLoc(op->getLoc());
3365   TF_RETURN_IF_ERROR(ParallelLoopEmitter(
3366                          [=](const IrArray::Index& index) {
3367                            return init_array.EmitReadArrayElement(index, &b_);
3368                          },
3369                          dest_array, launch_dimensions, &b_)
3370                          .EmitLoop(mlir::GetNameFromLoc(op->getLoc())));
3371 
3372   // Convert unique_ptr<KernelThunk> to StatusOr<unique_ptr<Thunk>>.
3373   return {std::move(kernel_thunk)};
3374 }
3375 
3376 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildFusedInitializerThunk(
3377     mlir::lmhlo::FusionOp fusion, int output_index) {
3378   auto reduce = mlir::dyn_cast_or_null<mlir::mhlo::ReduceOp>(
3379       fusion.getFusionResults()[output_index].getDefiningOp());
3380 
3381   TF_RET_CHECK(reduce);
3382   TF_RET_CHECK(reduce.getNumResults() == 1);
3383 
3384   mlir::Value init_value = reduce.init_values()[0];
3385   mlir::Value dest = fusion.getOutputBuffers()[output_index];
3386   TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> constant_init_thunk,
3387                       TryBuildConstantInitializerThunk(init_value, dest));
3388   if (constant_init_thunk) {
3389     return {std::move(constant_init_thunk)};
3390   }
3391 
3392   auto input_buffers = fusion.getInputBuffers();
3393 
3394   const Shape dest_shape = GetShape(dest);
3395   TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
3396                       CalculateLaunchDimensions(
3397                           dest_shape, ir_emitter_context_->gpu_device_info()));
3398   std::vector<llvm_ir::IrArray> ir_arrays;
3399   TF_ASSIGN_OR_RETURN(std::unique_ptr<KernelThunk> kernel_thunk,
3400                       BuildKernelThunk(fusion, Thunk::ThunkInfo(), &ir_arrays,
3401                                        launch_dimensions));
3402 
3403   const llvm_ir::IrArray dest_array =
3404       ir_arrays[input_buffers.size() + output_index];
3405 
3406   const HloComputation* fused_computation =
3407       *GetOrCreateSubComputationFromRegion(&fusion.region(),
3408                                            /*is_fusion=*/true);
3409 
3410   // If init_value was fused into this reduce we have to generate it first.
3411   GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
3412                                           ir_emitter_context_->llvm_module(),
3413                                           &b_, GetNestedComputer());
3414 
3415   FusedIrEmitter fused_emitter(&elemental_emitter);
3416   for (int i = 0; i < fused_computation->num_parameters(); i++) {
3417     fused_emitter.BindGenerator(
3418         fused_computation->parameter_instruction(i),
3419         [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
3420           return ir_arrays[i].EmitReadArrayElement(index, &b_);
3421         });
3422   }
3423   HloInstruction* instr = fused_computation->root_instruction();
3424   if (instr->opcode() != HloOpcode::kTuple) {
3425     CHECK_EQ(0, output_index);
3426   } else {
3427     instr = instr->mutable_operand(output_index);
3428   }
3429   TF_RET_CHECK(instr->shape().IsArray());
3430   TF_ASSIGN_OR_RETURN(auto generator,
3431                       fused_emitter.GetGenerator(instr->operand(1)));
3432   TF_RETURN_IF_ERROR(
3433       ParallelLoopEmitter(generator, dest_array, launch_dimensions, &b_)
3434           .EmitLoop(mlir::GetNameFromLoc(fusion.getLoc())));
3435   return {std::move(kernel_thunk)};
3436 }
3437 
3438 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildWhileThunk(
3439     mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info) {
3440   // Generate thunk sequence for while 'condition'.
3441   mlir::Region* condition = &while_op.cond();
3442   TF_ASSIGN_OR_RETURN(
3443       auto ir_emitter_condition,
3444       IrEmitterUnnested::Create(hlo_module_config_, ir_emitter_context_));
3445 
3446   TF_RETURN_IF_ERROR(ir_emitter_condition->EmitLmhloRegion(condition));
3447 
3448   // Generate thunk sequence for while 'body'.
3449   mlir::Region* body = &while_op.body();
3450   TF_ASSIGN_OR_RETURN(
3451       auto ir_emitter_body,
3452       IrEmitterUnnested::Create(hlo_module_config_, ir_emitter_context_));
3453 
3454   TF_RETURN_IF_ERROR(ir_emitter_body->EmitLmhloRegion(body));
3455 
3456   // Extract the condition value from the last op (excluding the terminator op)
3457   // in the condition region.
3458   auto cond_result = GetHloOutputs(while_op);
3459   TF_RET_CHECK(cond_result.size() == 1);
3460   TF_ASSIGN_OR_RETURN(auto cond_result_slice,
3461                       GetAllocationSlice(cond_result[0]));
3462 
3463   return std::unique_ptr<Thunk>(
3464       new WhileThunk(thunk_info, cond_result_slice,
3465                      ir_emitter_condition->ConsumeThunkSequence(),
3466                      ir_emitter_body->ConsumeThunkSequence()));
3467 }
3468 
3469 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildForThunk(
3470     mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info,
3471     const int64_t loop_limit) {
3472   // Generate thunk sequence for while 'body' (will be used as the For loop body).
3473   TF_ASSIGN_OR_RETURN(
3474       auto ir_emitter_body,
3475       IrEmitterUnnested::Create(hlo_module_config_, ir_emitter_context_));
3476   TF_RETURN_IF_ERROR(ir_emitter_body->EmitLmhloRegion(&while_op.body()));
3477 
3478   return std::unique_ptr<Thunk>(new ForThunk(
3479       thunk_info, loop_limit, ir_emitter_body->ConsumeThunkSequence()));
3480 }
3481 
3482 Status IrEmitterUnnested::EmitTargetElementLoop(
3483     const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter) {
3484   return InternalError("This should be unreachable");
3485 }
3486 
3487 // Gets the output offset as calculated from thread_id.x (to be applied to the
3488 // offset calculated from block_id and thread_id.y).
3489 static llvm::Value* GetStartOffsetX(const KernelMappingScheme& mapping_scheme,
3490                                     llvm::Value* thread_id_x,
3491                                     llvm::Type* index_ty,
3492                                     llvm::IRBuilder<>* b) {
3493   auto constant = [&](int64_t val) {
3494     return llvm::ConstantInt::get(index_ty, val);
3495   };
3496   if (mapping_scheme.GetIndexingOrder() == kStridedIndexingX) {
3497     return thread_id_x;
3498   } else if (mapping_scheme.GetIndexingOrder() == kStridedLinearIndexingX) {
3499     return b->CreateMul(thread_id_x, constant(mapping_scheme.GetVectorSize()));
3500   }
3501   CHECK_EQ(mapping_scheme.GetIndexingOrder(), kLinearIndexingX);
3502   int64_t x_num_steps =
3503       mapping_scheme.GetTileSizeX() / mapping_scheme.GetNumThreadsX();
3504   return b->CreateMul(thread_id_x, constant(x_num_steps));
3505 }
3506 
3507 // Calls `emit_elem_function()` `x_num_steps` times.  If
3508 // `vector_size`==1, then each element index passed to
3509 // `emit_elem_function()` will be separated by `step_x`. If `vector_size`>1,
3510 // then `x_num_steps` must be a multiple of `vector_size`.  In that case, it
3511 // triggers a different indexing order that is vectorizable by
3512 // LLVM. It generates many groups of calls to `emit_elem_function`. Each
3513 // group is separated by `step_x` elements.  Inside a group, elements
3514 // are consecutive. If `check_x_tile_bounds` is true, then it will check
3515 // if the element index is in bound compared to `tile_width` before
3516 // calling `emit_elem_function`.
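//
// Illustrative example (not part of the original comment): with x_num_steps = 4,
// vector_size = 2, step_x = 32 and start_offset_x = 0, the emitted x offsets are
// {0, 1} for group j = 0 and {64, 65} for group j = 1, i.e.
// j * step_x * vector_size + i for each (j, i).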
3517 static void UnrollInnerTileLoop(
3518     bool check_x_tile_bounds, int64_t x_num_steps, int64_t step_x,
3519     int64_t vector_size, const string& loop_name, KernelSupportLibrary* ksl,
3520     llvm::Value* start_offset_x, llvm::Value* y_loc, llvm::Value* tile_width,
3521     const IrArray::Index& source_idx, llvm::IRBuilder<>* b,
3522     const IrEmitterUnnested::EmitElementFunction* emit_elem_function) {
3523   llvm::Type* index_ty = tile_width->getType();
3524   auto constant = [&](int64_t val) {
3525     return llvm::ConstantInt::get(index_ty, val);
3526   };
3527   IrArray::Index source_idx_x_base = source_idx.AddOffsetToDim(y_loc, kDimY, b);
3528   for (int64_t j = 0; j < x_num_steps / vector_size; j++) {
3529     for (int64_t i = 0; i < vector_size; i++) {
3530       int64_t linear_index = j * vector_size + i;
3531       llvm::Value* x_loc = b->CreateAdd(constant(j * step_x * vector_size + i),
3532                                         start_offset_x, "x_loc");
3533       IrArray::Index source_idx_x = source_idx_x_base.AddOffsetToDim(
3534           constant(j * step_x * vector_size + i), kDimX, b);
3535       auto emit_element = [&] {
3536         return (*emit_elem_function)(source_idx_x, y_loc, x_loc, linear_index);
3537       };
3538       if (check_x_tile_bounds) {
3539         ksl->If(loop_name + "_x_in_tile", b->CreateICmpULT(x_loc, tile_width),
3540                 emit_element);
3541       } else {
3542         emit_element();
3543       }
3544     }
3545   }
3546 }
3547 
3548 void IrEmitterUnnested::EmitTile(
3549     const KernelMappingScheme& mapping_scheme,
3550     const IrArray::Index& tile_origin_index, const string& loop_name,
3551     KernelSupportLibrary* ksl, const ThreadIdInfo& thread_id_info,
3552     llvm::Value* tile_height, llvm::Value* tile_width,
3553     const IrEmitterUnnested::EmitElementFunction& emit_elem_function) {
3554   llvm::Type* index_ty = tile_width->getType();
3555   auto constant = [&](int64_t val) {
3556     return llvm::ConstantInt::get(index_ty, val);
3557   };
3558   int64_t num_threads_x = mapping_scheme.GetNumThreadsX();
3559   llvm::Value* num_threads_y = constant(mapping_scheme.GetNumThreadsY());
3560   int64_t tile_size_x = mapping_scheme.GetTileSizeX();
3561 
3562   int64_t x_num_steps = tile_size_x / num_threads_x;
3563   llvm::Value* start_offset_x = GetStartOffsetX(
3564       mapping_scheme, thread_id_info.thread_id_x, index_ty, &b_);
3565 
3566   // With a dilated (strided) mapping scheme, each thread steps with a stride
3567   // equal to the number of threads.
3568   // Otherwise, the stride is one, and each thread's starting offset is
3569   // multiplied by the number of steps it can take.
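  // Illustrative example (not part of the original comment): with 32 threads in
  // X and x_num_steps = 4, linear indexing has thread t cover elements
  // [4t, 4t + 3], while strided indexing (with vector_size = 1) has it cover
  // {t, t + 32, t + 64, t + 96}.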
3570   int64_t step_x =
3571       mapping_scheme.GetIndexingOrder() == kLinearIndexingX ? 1 : num_threads_x;
3572   int64_t vector_size = mapping_scheme.GetVectorSize();
3573 
3574   IrArray::Index source_idx =
3575       tile_origin_index.AddOffsetToDim(start_offset_x, kDimX, &b_);
3576 
3577   // True iff all threads always execute all instructions in the tiling
3578   // dimension X.
3579   bool x_tile_fits =
3580       mapping_scheme.GetDimsInElems()[kDimX] % tile_size_x == 0 &&
3581       mapping_scheme.GetRowContiguous();
3582 
3583   ksl->For(
3584       loop_name + "_y_in_tile",
3585       /*start=*/thread_id_info.thread_id_y,
3586       /*end=*/
3587       tile_height,
3588       /*step=*/num_threads_y, [&](llvm::Value* y_loc) {
3589         auto unroll_inner_tile_loop = [&](bool check_x_tile_bounds) {
3590           return UnrollInnerTileLoop(check_x_tile_bounds, x_num_steps, step_x,
3591                                      vector_size, loop_name, ksl,
3592                                      start_offset_x, y_loc, tile_width,
3593                                      source_idx, &b_, &emit_elem_function);
3594         };
3595 
3596         // Only take this path when we unroll in a way vectorizable by
3597         // LLVM. Special case when the tile doesn't fit completely for even
3598         // row size. For odd row size every other row isn't aligned to the
3599         // vectorized size, so it can't be vectorized by LLVM.
3600         if (!x_tile_fits &&
3601             mapping_scheme.GetIndexingOrder() == kStridedLinearIndexingX) {
3602           ksl->If(
3603               loop_name + "_is_full_tile",
3604               // For the last block, tile_width will be the number of
3605               // elements left.
3606               b_.CreateICmpEQ(constant(mapping_scheme.GetTileSizeX()),
3607                               tile_width),
3608               [&] { unroll_inner_tile_loop(/*check_x_tile_bounds=*/false); },
3609               [&] { unroll_inner_tile_loop(/*check_x_tile_bounds=*/true); });
3610         } else {
3611           unroll_inner_tile_loop(/*check_x_tile_bounds=*/!x_tile_fits);
3612         }
3613       });
3614 }
3615 
3616 static IrArray::Index GetUnnormalizedIndex(
3617     const IrArray::Index& normalized_shape_index,
3618     const Shape& unnormalized_shape, llvm::IRBuilder<>* b_,
3619     const KernelMappingScheme& kernel_mapping_scheme) {
3620   DCHECK_EQ(normalized_shape_index.size(), 3);
3621   // If the normalization only adds a new dimension of size 1,
3622   // generate simpler indexing. LLVM doesn't always simplify the more
3623   // complicated indexing, and this prevents it from vectorizing some
3624   // cases. We do this only for major_to_minor memory layout.
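  // Illustrative example (not part of the original comment): a normalized index
  // (0, i, j) into a [1, 20, 30] view maps directly to index (i, j) of an
  // unnormalized f32[20,30]{1,0} shape, skipping the generic linearization
  // below.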
3625   if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() &&
3626       unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[1] &&
3627       unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[2] &&
3628       unnormalized_shape.layout().minor_to_major(1) == 0) {
3629     CHECK_EQ(normalized_shape_index.dims()[0], 1);
3630     auto multidim = normalized_shape_index.multidim();
3631     return IrArray::Index({multidim[1], multidim[2]}, unnormalized_shape,
3632                           normalized_shape_index.GetType());
3633   }
3634   if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() &&
3635       unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[2] &&
3636       unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[1] &&
3637       unnormalized_shape.layout().minor_to_major(1) == 1) {
3638     CHECK_EQ(normalized_shape_index.dims()[0], 1);
3639     auto multidim = normalized_shape_index.multidim();
3640     return IrArray::Index({multidim[2], multidim[1]}, unnormalized_shape,
3641                           normalized_shape_index.GetType());
3642   }
3643   llvm::Value* linear = normalized_shape_index.Linearize(
3644       kernel_mapping_scheme.GetDimsInElems(), b_);
3645   return IrArray::Index(linear, unnormalized_shape, b_);
3646 }
3647 
3648 // Emits code to process a tensor element in a tile for the given kLoop fusion
3649 // HLO containing parameters that are a 0-2-1 transpose of its outputs.
3650 //
3651 // index: The index for the first output element in the normalized tensor, that
3652 //   is the resulting tensor after collapsing contiguous dimensions that play
3653 //   the same role in the transpose.
3654 // mapping_scheme: Other information to support the kernel code generation.
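//
// Illustrative example (not part of the original comment): a parameter of shape
// [2, 3, 40, 50] read as a {0, 1, 3, 2} transpose of the output normalizes to
// the three-component shape [6, 40, 50] (vs. [6, 50, 40] for the output),
// because dimensions 0 and 1 play the same role in the transpose and collapse
// into one component.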
3655 void IrEmitterUnnested::EmitTileElementForFusion(
3656     mlir::lmhlo::FusionOp fusion,
3657     absl::Span<const llvm_ir::IrArray> operand_arrays,
3658     absl::Span<const llvm_ir::IrArray> output_arrays,
3659     const llvm_ir::IrArray::Index& index,
3660     const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
3661     llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) {
3662   const HloComputation* fused_computation =
3663       *GetOrCreateSubComputationFromRegion(&fusion.region(),
3664                                            /*is_fusion=*/true);
3665   GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
3666                                      GetNestedComputer());
3667   FusedIrEmitter fused_emitter(&elem_emitter);
3668   for (int i = 0; i < operand_arrays.size(); i++) {
3669     llvm_ir::ElementGenerator gen;
3670     if (llvm::Value* param_tile_buffer = param_shmem_buffers[i]) {
3671       gen = [this, param_tile_buffer, x_loc,
3672              y_loc](llvm_ir::IrArray::Index index) {
3673         // TODO(jlebar): Add AA metadata to this load.  Tile buffers are
3674         // global variables, so LLVM's points-to analysis doesn't help us
3675         // much.  And we want the AA info to be present before address
3676         // spaces are inferred (which is pretty late in the pipeline), so
3677         // even if we had address-space-based AA in LLVM, it wouldn't help
3678         // us much here.
3679         return b_.CreateLoad(
3680             b_.CreateGEP(param_tile_buffer,
3681                          {index.GetConstantWithIndexType(0), x_loc, y_loc}),
3682             "tiled_buffer");
3683       };
3684     } else {
3685       auto array = operand_arrays[i];
3686       auto name = fused_computation->parameter_instruction(i)->name();
3687       gen = [this, array, name](const llvm_ir::IrArray::Index& index) {
3688         return array.EmitReadArrayElement(index, &b_, name);
3689       };
3690     }
3691     fused_emitter.BindGenerator(fused_computation->parameter_instruction(i),
3692                                 std::move(gen));
3693   }
3694   IrArray::Index untiled_index = GetUnnormalizedIndex(
3695       index, output_arrays[0].GetShape(), &b_, mapping_scheme);
3696   llvm_ir::ElementGenerator output_generator =
3697       *fused_emitter.GetGenerator(fused_computation->root_instruction());
3698   llvm::Value* output_value = output_generator(untiled_index).ValueOrDie();
3699   if (output_arrays.size() > 1) {
3700     DCHECK(output_value->getType()->isStructTy());
3701     DCHECK_EQ(output_value->getType()->getStructNumElements(),
3702               output_arrays.size());
3703     for (int64_t i = 0; i < output_arrays.size(); ++i) {
3704       output_arrays[i].EmitWriteArrayElement(
3705           untiled_index, ExtractValue(output_value, i), &b_);
3706     }
3707   } else {
3708     output_arrays[0].EmitWriteArrayElement(untiled_index, output_value, &b_);
3709   }
3710 }
3711 
3712 static mlir::Operation* GetReduceFromUnnestedMlir(mlir::Operation* unnested_hlo,
3713                                                   int index) {
3714   auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(unnested_hlo);
3715   auto results = fusion.getFusionResults();
3716   CHECK(index < results.size())
3717       << MlirToString(unnested_hlo) << " vs " << index;
3718   return results[index].getDefiningOp();
3719 }
3720 
3721 void IrEmitterUnnested::EmitPrologueForReduction(
3722     mlir::Operation* unnested_hlo, absl::Span<const int> instr_index_group,
3723     HloComputation* fused_computation, FusedIrEmitter* fused_emitter,
3724     absl::Span<const llvm_ir::IrArray> result_ir_arrays,
3725     ReductionCodegenState* reduction_info,
3726     const FusionLayoutAnalysis& layout_analysis) {
3727   VLOG(10) << "Emit prologue for reduction: " << MlirToString(unnested_hlo);
3728   mlir::Operation* first_reduce = nullptr;
3729   for (int index : instr_index_group) {
3730     mlir::Operation* reduce_inst =
3731         GetReduceFromUnnestedMlir(unnested_hlo, index);
3732 
3733     if (!IsReductionFromOrToContiguousDimensions(reduce_inst,
3734                                                  layout_analysis)) {
3735       continue;
3736     }
3737 
3738     auto results = GetHloOutputs(reduce_inst);
3739     CHECK_EQ(1, results.size());
3740     Shape reduce_inst_shape = layout_analysis.GetShape(results[0]);
3741 
3742     VLOG(10) << "Emit prologue for reduction: " << MlirToString(reduce_inst);
3743     if (first_reduce == nullptr) {
3744       first_reduce = reduce_inst;
3745     } else {
3746       CHECK(absl::c_equal(
3747           first_reduce->getAttrOfType<mlir::DenseIntElementsAttr>("dimensions"),
3748           reduce_inst->getAttrOfType<mlir::DenseIntElementsAttr>(
3749               "dimensions")));
3750     }
3751 
3752     AddressVector* reduction_input_addresses =
3753         reduction_info->GetMutableReductionInputAddresses();
3754     llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType(
3755         reduce_inst_shape.element_type(), ir_emitter_context_->llvm_module());
3756     llvm::AllocaInst* reduction_input_address =
3757         llvm_ir::EmitAllocaAtFunctionEntry(element_type,
3758                                            "reduction_input_address", &b_);
3759     reduction_input_addresses->push_back(reduction_input_address);
3760 
3761     int num_partial_results = reduction_info->GetNumPartialResults();
3762     AddressVector* partial_result_addresses =
3763         reduction_info->GetMutablePartialResultAddresses();
3764     llvm::AllocaInst* partial_result_address =
3765         llvm_ir::EmitAllocaAtFunctionEntryWithCount(
3766             element_type, /*element_count=*/b_.getInt32(num_partial_results),
3767             ("partial_reduction_result." + llvm::Twine(index)).str(), &b_);
3768     partial_result_addresses->push_back(partial_result_address);
3769 
3770     // Initialize the partial result with the initial value of the reduction.
3771     llvm::Value* init_ir_value;
3772     const HloInstruction* reduce_hlo = fused_computation->root_instruction();
3773     if (reduce_hlo->opcode() == HloOpcode::kTuple) {
3774       reduce_hlo = reduce_hlo->operand(index);
3775     }
3776     const HloInstruction* init_value = reduce_hlo->operand(1);
3777     init_ir_value = (*fused_emitter->GetGenerator(init_value))(
3778                         IrArray::Index(b_.getInt32Ty()))
3779                         .ValueOrDie();
3780 
3781     for (int i = 0; i < num_partial_results; ++i) {
3782       b_.CreateStore(init_ir_value, b_.CreateInBoundsGEP(partial_result_address,
3783                                                          {b_.getInt32(i)}));
3784     }
3785     reduction_info->GetMutableInitialValues()->push_back(init_ir_value);
3786 
3787     auto& mapping_scheme = reduction_info->GetKernelMappingScheme();
3788     int64_t num_threads_x = mapping_scheme.GetNumThreadsX();
3789     llvm::Type* primitive_type = llvm_ir::PrimitiveTypeToIrType(
3790         reduce_inst_shape.element_type(), module_);
3791     llvm::Type* buffer_type = [&] {
3792       if (reduction_info->IsRowReduction()) {
3793         // Allocate __shared__ cache[num_partial_results][kWarpSize].
3794         // TODO(cheshire): Do we need the same trick as below to avoid bank
3795         // conflicts?
3796         return llvm::ArrayType::get(
3797             llvm::ArrayType::get(primitive_type, kWarpSize),
3798             num_partial_results);
3799       } else {
3800         // Allocate __shared__
3801         // cache[num_partial_results][num_threads][num_threads + 1], where
3802         // num_threads == num_threads_x == num_threads_y.  The "+1" is used to
3803         // avoid bank conflicts.
3804         CHECK_EQ(num_threads_x, mapping_scheme.GetNumThreadsY());
3805         return llvm::ArrayType::get(
3806             llvm::ArrayType::get(
3807                 llvm::ArrayType::get(primitive_type, num_threads_x + 1),
3808                 num_threads_x),
3809             num_partial_results);
3810       }
3811     }();
3812     llvm::GlobalVariable* shared_cache_per_reduce =
3813         llvm_ir::AllocateSharedMemoryTile(b_.GetInsertBlock()->getModule(),
3814                                           buffer_type,
3815                                           absl::StrCat("shared_cache_", index));
3816     reduction_info->GetMutableSharedCache()->push_back(shared_cache_per_reduce);
3817   }
3818   CHECK(first_reduce);
3819 }
3820 
3821 void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForAllReduces(
3822     absl::Span<HloComputation* const> reducers,
3823     absl::Span<llvm::AllocaInst* const> partial_result_addresses,
3824     int threads_per_block) {
3825   CHECK_EQ(reducers.size(), partial_result_addresses.size());
3826   for (int i = 0; i != reducers.size(); i++) {
3827     EmitFullWarpShuffleDownLoopForReduce(
3828         reducers[i], partial_result_addresses[i]->getType()->getElementType(),
3829         partial_result_addresses[i], threads_per_block);
3830   }
3831 }
3832 
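// Roughly, the IR emitted below corresponds to the following CUDA-style warp
// reduction (an illustrative sketch only, assuming a scalar partial result):
//
//   for (int distance = 16; distance >= 1; distance /= 2)
//     partial = reducer(partial,
//                       __shfl_down_sync(0xffffffff, partial, distance));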
3833 void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForReduce(
3834     HloComputation* reducer, llvm::Type* element_type,
3835     llvm::Value* partial_result_address, int threads_per_block) {
3836   // This only works when the block size is a multiple of 32 threads.
3837   CHECK_EQ(threads_per_block % 32, 0);
3838   for (int distance = 16; distance >= 1; distance /= 2) {
3839     int bit_width = llvm_ir::GetSizeInBits(element_type);
3840     llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry(
3841         element_type, "result_from_other_lane", &b_);
3842     // Bitcast cannot be applied to aggregate types (even packed ones), so
3843     // we bitcast addresses of load/store to intN* of the same bit-width.
3844     llvm::Type* shuffled_value_type =
3845         element_type->isStructTy() ? b_.getIntNTy(bit_width) : element_type;
3846     auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) {
3847       return b_.CreatePointerBitCastOrAddrSpaceCast(
3848           ptr, shuffled_value_type->getPointerTo());
3849     };
3850     llvm::Value* partial_result =
3851         Load(convert_pointer_for_shuffle(partial_result_address),
3852              "partial_reduction_result");
3853     Store(EmitFullWarpShuffleDown(partial_result, b_.getInt32(distance), &b_),
3854           convert_pointer_for_shuffle(result_from_other_lane));
3855     TF_CHECK_OK(EmitCallToNestedComputation(
3856         *reducer, {partial_result_address, result_from_other_lane},
3857         partial_result_address));
3858   }
3859 }
3860 
3861 // Given the IrArray index of a reduction input, returns the linear address of
3862 // the reduction output as if the reduction were going to keep the input shape
3863 // with the dimensions being reduced moved.
3864 static llvm::Value* GetUntransposedOutputLinearAddress(
3865     llvm::IRBuilder<>* b, const llvm_ir::IrArray::Index& index,
3866     const ReductionCodegenState& reduction_info) {
3867   const KernelMappingScheme& kernel_mapping_scheme =
3868       reduction_info.GetKernelMappingScheme();
3869   if (reduction_info.IsRowReduction()) {
3870     // For row-reduction, y-coordinate determines which row we write into.
3871     return index[kDimY];
3872   }
3873   // For column reduction, we get the transposed address.
3874   absl::Span<const int64> dims_in_elem = kernel_mapping_scheme.GetDimsInElems();
3875   llvm::Value* x_dim_size = index.GetConstantWithIndexType(dims_in_elem[kDimX]);
3876   llvm::Value* x_block_offset = b->CreateMul(index[kDimZ], x_dim_size);
3877   return b->CreateAdd(x_block_offset, index[kDimX]);
3878 }
3879 
3880 void IrEmitterUnnested::EmitEpilogueForReduction(
3881     llvm::Type* index_ty, mlir::Operation* unnested_hlo,
3882     absl::Span<const int> instr_index_group,
3883     absl::Span<const llvm_ir::IrArray> result_ir_arrays,
3884     absl::Span<HloComputation* const> reducers,
3885     const ReductionCodegenState& reduction_info,
3886     const TilingKernelInfo& tiling_kernel_info,
3887     const FusionLayoutAnalysis& layout_analysis) {
3888   const KernelMappingScheme& mapping_scheme =
3889       reduction_info.GetKernelMappingScheme();
3890   auto constant = [&](uint64 c) -> llvm::Constant* {
3891     return llvm::ConstantInt::get(index_ty, c);
3892   };
3893 
3894   IrEmitterUnnested::ThreadIdInfo thread_id_info =
3895       EmitThreadIdInfo(mapping_scheme.GetThreadsPerBlock(), index_ty,
3896                        mapping_scheme.GetNumThreadsX());
3897 
3898   IrArray::Index start_offset = [&] {
3899     llvm::Value* x_loc = thread_id_info.thread_id_x;
3900     llvm::Value* y_loc = thread_id_info.thread_id_y;
3901     if (!reduction_info.IsRowReduction()) {
3902       std::swap(x_loc, y_loc);
3903     }
3904     llvm::Value* start_offset_x =
3905         GetStartOffsetX(mapping_scheme, x_loc, index_ty, &b_);
3906     return tiling_kernel_info.tile_origin.AddOffsetToDim(y_loc, kDimY, &b_)
3907         .AddOffsetToDim(start_offset_x, kDimX, &b_);
3908   }();
3909 
3910   absl::Span<llvm::AllocaInst* const> partial_result_addresses =
3911       reduction_info.GetPartialResultAddresses();
3912 
3913   int reduction_idx = -1;
3914 
3915   // `instruction_idx` indexes over all instructions in a group; some of
3916   // them might not be unnested reductions.
3917   for (int instruction_idx : instr_index_group) {
3918     mlir::Operation* reduce_hlo =
3919         GetReduceFromUnnestedMlir(unnested_hlo, instruction_idx);
3920     llvm_ir::IrArray output_array = result_ir_arrays[instruction_idx];
3921     if (!IsReductionFromOrToContiguousDimensions(reduce_hlo, layout_analysis)) {
3922       continue;
3923     }
3924     reduction_idx++;
3925 
3926     Shape operand_shape = layout_analysis.GetShape(reduce_hlo->getOperand(0));
3927     Shape reduction_kept_element_shape = ShapeUtil::FilterDimensions(
3928         [&](int64_t dim) {
3929           return !absl::c_linear_search(
3930               reduce_hlo->getAttrOfType<mlir::DenseIntElementsAttr>(
3931                   "dimensions"),
3932               dim);
3933         },
3934         operand_shape);
3935 
3936     for (int partial_result_idx = 0;
3937          partial_result_idx < reduction_info.GetNumPartialResults();
3938          ++partial_result_idx) {
3939       llvm::Value* untransposed_output_linear_address =
3940           GetUntransposedOutputLinearAddress(
3941               &b_,
3942               start_offset.AddOffsetToDim(constant(partial_result_idx), kDimX,
3943                                           &b_),
3944               reduction_info);
3945 
3946       // A reduction is allowed to transpose its output.  For example, suppose
3947       // we are reducing the second dimension of f32[10,20,30]{2,1,0}.  We are
3948       // allowed to produce as output either f32[10,30]{1,0} (no transpose) or
3949       // f32[10,30]{0,1} (transposing the two output dims).
3950       //
3951       // At this point in the function we have a "partial sum" of input elements
3952       // (stored in partial_result_addresses), and we need to accumulate it into
3953       // the correct output element.
3954       IrArray::Index element_index(
3955           /*linear=*/untransposed_output_linear_address,
3956           reduction_kept_element_shape, &b_);
3957       IrArray::Index output_index(element_index.multidim(),
3958                                   output_array.GetShape(),
3959                                   element_index.GetType());
3960       llvm::Value* output_address = output_array.EmitArrayElementAddress(
3961           output_index, &b_, "output_element_address");
3962       llvm::Value* current_output = b_.CreateInBoundsGEP(
3963           partial_result_addresses[reduction_idx],
3964           {constant(partial_result_idx)}, "current_output");
3965 
3966       llvm::Type* element_type =
3967           partial_result_addresses[reduction_idx]->getType()->getElementType();
3968       if (reduction_info.IsRowReduction()) {
3969         EmitEpilogueForRowReduction(reducers[reduction_idx], thread_id_info,
3970                                     reduction_info, element_type, index_ty,
3971                                     current_output, output_address,
3972                                     reduction_idx, partial_result_idx);
3973       } else {
3974         EmitEpilogueForColumnReduction(
3975             reducers[reduction_idx], thread_id_info, reduction_info,
3976             element_type, index_ty, current_output, output_address,
3977             reduction_idx, partial_result_idx, tiling_kernel_info);
3978       }
3979     }
3980   }
3981 }
3982 
3983 llvm::Value* IrEmitterUnnested::EmitBlockId() {
3984   return gpu::EmitCallToTargetIntrinsic(gpu::TargetIntrinsicID::kBlockIdx, {},
3985                                         {}, &b_);
3986 }
3987 
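// Hypothetical usage for debugging (not present in the original source):
//
//   EmitPrintfWithThreadId("x_loc=%d", {x_loc},
//                          /*thread_id_filter=*/0, /*block_id_filter=*/0);
//
// would print the value of a hypothetical `x_loc` only from thread 0 of
// block 0, prefixed with "[TID=%d,BID=%d] ".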
3988 void IrEmitterUnnested::EmitPrintfWithThreadId(
3989     absl::string_view fmt, absl::Span<llvm::Value* const> arguments,
3990     absl::optional<int64> thread_id_filter,
3991     absl::optional<int64> block_id_filter) {
3992   llvm::Value* thread_id = EmitThreadId(1024, b_.getInt32Ty());
3993   llvm::Value* block_id = EmitBlockId();
3994   std::vector<llvm::Value*> updated_arguments = {thread_id, block_id};
3995   updated_arguments.insert(updated_arguments.end(), arguments.begin(),
3996                            arguments.end());
3997   llvm::Value* constraint = b_.getTrue();
3998   if (thread_id_filter) {
3999     constraint = b_.CreateAnd(
4000         constraint, b_.CreateICmpEQ(thread_id, b_.getInt32(*thread_id_filter)));
4001   }
4002   if (block_id_filter) {
4003     constraint = b_.CreateAnd(
4004         constraint, b_.CreateICmpEQ(block_id, b_.getInt32(*block_id_filter)));
4005   }
4006   KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
4007   ksl.If(constraint, [&] {
4008     xla::gpu::EmitPrintf(absl::StrCat("[TID=%d,BID=%d] ", fmt, "\n"),
4009                          updated_arguments, &b_);
4010   });
4011 }
4012 llvm::Value* IrEmitterUnnested::CastSharedToGlobal(llvm::Value* input,
4013                                                    llvm::Twine name) {
4014   return b_.CreateAddrSpaceCast(
4015       input,
4016       llvm::PointerType::get(input->getType()->getPointerElementType(),
4017                              /*AddressSpace=*/0),
4018       name);
4019 }
4020 
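// A rough sketch of the row-reduction epilogue emitted below, in CUDA-like
// pseudocode (illustrative only):
//
//   partial = warp_shuffle_reduce(partial);        // intra-warp reduction
//   if (lane_id == 0) shared[warp_id] = partial;   // one partial per warp
//   __syncthreads();
//   if (warp_id == 0) {
//     value = lane_id < num_warps ? shared[lane_id] : init;
//     value = warp_shuffle_reduce(value);          // reduce the per-warp values
//     if (thread_id_x == 0) write_or_atomically_accumulate(output, value);
//   }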
4021 void IrEmitterUnnested::EmitEpilogueForRowReduction(
4022     HloComputation* reducer,
4023     const IrEmitterUnnested::ThreadIdInfo& thread_id_info,
4024     const ReductionCodegenState& reduction_info, llvm::Type* element_type,
4025     llvm::Type* index_ty, llvm::Value* current_output,
4026     llvm::Value* output_address, int reduction_idx, int partial_result_idx) {
4027   auto constant = [&](uint64 c) -> llvm::Constant* {
4028     return llvm::ConstantInt::get(index_ty, c);
4029   };
4030   auto is_zero = [&](llvm::Value* value) {
4031     return b_.CreateICmpEQ(value, constant(0));
4032   };
4033   llvm::GlobalVariable* shared_cache =
4034       reduction_info.GetSharedCache()[reduction_idx];
4035   KernelSupportLibrary ksl(&b_);
4036   const KernelMappingScheme& mapping_scheme =
4037       reduction_info.GetKernelMappingScheme();
4038 
4039   EmitFullWarpShuffleDownLoopForReduce(reducer, element_type, current_output,
4040                                        mapping_scheme.GetThreadsPerBlock());
4041   llvm::Value* warp_id =
4042       b_.CreateUDiv(thread_id_info.thread_id_x, constant(kWarpSize));
4043   ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] {
4044     llvm::Value* shmem_output_addr = CastSharedToGlobal(b_.CreateInBoundsGEP(
4045         shared_cache, {b_.getInt32(0), constant(partial_result_idx), warp_id}));
4046     b_.CreateStore(b_.CreateLoad(current_output), shmem_output_addr);
4047   });
4048 
4049   // TODO(cheshire): Don't we want to sync it once for everything in the
4050   // output? Not once per each?
4051   EmitSyncThreads();
4052   ksl.If("inter_warp_reduce", is_zero(warp_id), [&] {
4053     llvm::Value* block_accum_addr = CastSharedToGlobal(b_.CreateInBoundsGEP(
4054         shared_cache, {b_.getInt32(0), constant(partial_result_idx),
4055                        thread_id_info.lane_id}));
4056     llvm::Value* initial_value =
4057         reduction_info.GetInitialValues()[reduction_idx];
4058     llvm::Value* initial_value_addr =
4059         CastSharedToGlobal(llvm_ir::EmitAllocaAtFunctionEntry(
4060             element_type, "initial_value_addr", &b_));
4061     b_.CreateStore(initial_value, initial_value_addr);
4062 
4063     llvm::Value* warp_exists =
4064         b_.CreateICmpULT(thread_id_info.thread_id_x,
4065                          constant(mapping_scheme.GetNumThreadsX() / kWarpSize));
4066 
4067     llvm::Value* selected_value =
4068         b_.CreateSelect(warp_exists, block_accum_addr, initial_value_addr);
4069 
4070     EmitFullWarpShuffleDownLoopForReduce(reducer, element_type,
4071                                          /*block_accum_addr*/ selected_value,
4072                                          mapping_scheme.GetThreadsPerBlock());
4073     ksl.If("reduction_write_output", is_zero(thread_id_info.thread_id_x), [&] {
4074       if (reduction_info.IsRaceFree()) {
4075         VLOG(10) << "Using deterministic reductions: writing out "
4076                     "the value directly";
4077         b_.CreateStore(b_.CreateLoad(block_accum_addr, "output"),
4078                        output_address);
4079       } else {
4080         TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
4081             *reducer, output_address, block_accum_addr));
4082       }
4083     });
4084   });
4085 }
4086 
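// A rough sketch of the column-reduction epilogue emitted below, in CUDA-like
// pseudocode (illustrative only):
//
//   shared[x][y] = partial;              // stage partials in shared memory
//   __syncthreads();
//   value = shared[y][x];                // read the transposed element
//   value = warp_shuffle_reduce(value);
//   if (lane_id == 0 && in_bounds) write_or_atomically_accumulate(output, value);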
4087 void IrEmitterUnnested::EmitEpilogueForColumnReduction(
4088     HloComputation* reducer,
4089     const IrEmitterUnnested::ThreadIdInfo& thread_id_info,
4090     const ReductionCodegenState& reduction_info, llvm::Type* element_type,
4091     llvm::Type* index_ty, llvm::Value* current_output,
4092     llvm::Value* output_address, int reduction_idx, int partial_result_idx,
4093     const TilingKernelInfo& tiling_kernel_info) {
4094   KernelSupportLibrary ksl(&b_);
4095   llvm::GlobalVariable* shared_cache =
4096       reduction_info.GetSharedCache()[reduction_idx];
4097   auto constant = [&](uint64 c) -> llvm::Constant* {
4098     return llvm::ConstantInt::get(index_ty, c);
4099   };
4100   auto is_zero = [&](llvm::Value* value) {
4101     return b_.CreateICmpEQ(value, constant(0));
4102   };
4103   const KernelMappingScheme& mapping_scheme =
4104       reduction_info.GetKernelMappingScheme();
4105   llvm::Value* shmem_output_addr = CastSharedToGlobal(
4106       b_.CreateInBoundsGEP(
4107           shared_cache,
4108           {b_.getInt32(0), constant(partial_result_idx),
4109            thread_id_info.thread_id_x, thread_id_info.thread_id_y}),
4110       "shmem_output_address");
4111   llvm::Value* current_output_value = b_.CreateLoad(current_output);
4112   b_.CreateStore(current_output_value, shmem_output_addr);
4113 
4114   EmitSyncThreads();
4115 
4116   // Get transposed element from shared memory.
4117   llvm::Value* shmem_transposed_addr = CastSharedToGlobal(b_.CreateInBoundsGEP(
4118       shared_cache,
4119       {b_.getInt32(0), constant(partial_result_idx), thread_id_info.thread_id_y,
4120        thread_id_info.thread_id_x},
4121       "shmem_transposed_addr"));
4122 
4123   EmitFullWarpShuffleDownLoopForReduce(reducer, element_type,
4124                                        shmem_transposed_addr,
4125                                        mapping_scheme.GetThreadsPerBlock());
4126 
4127   // Some warps in the block are completely outside the bounds of the
4128   // tensor, so they should not write any output at all.
4129   llvm::Value* has_output = b_.CreateAnd(
4130       b_.CreateICmpULT(
4131           GetStartOffsetX(mapping_scheme, thread_id_info.thread_id_y, index_ty,
4132                           &b_),
4133           tiling_kernel_info.output_tile_bounds[kDimX]),
4134       b_.CreateICmpULT(thread_id_info.thread_id_x,
4135                        tiling_kernel_info.output_tile_bounds[kDimY]));
4136 
4137   ksl.If("reduction_write_output",
4138          b_.CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] {
4139            if (reduction_info.IsRaceFree()) {
4140              VLOG(10) << "Using deterministic reductions: writing out "
4141                          "the value directly";
4142              b_.CreateStore(
4143                  b_.CreateLoad(shmem_transposed_addr, "output_value"),
4144                  output_address);
4145            } else {
4146              TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
4147                  *reducer, output_address, shmem_transposed_addr));
4148            }
4149          });
4150 }
4151 
4152 void IrEmitterUnnested::EmitTileElementForReduction(
4153     mlir::Operation* unnested_hlo, const Shape& reduction_operand_shape,
4154     absl::Span<const int> instr_index_group, HloComputation* fused_computation,
4155     FusedIrEmitter* fused_emitter,
4156     absl::Span<const llvm_ir::IrArray> result_ir_arrays,
4157     absl::Span<HloComputation* const> reducers,
4158     const llvm_ir::IrArray::Index& index,
4159     const ReductionCodegenState& reduction_info, int64_t x_iter_num,
4160     const FusionLayoutAnalysis& layout_analysis) {
4161   VLOG(10) << "Emit tile element for reduce " << MlirToString(unnested_hlo);
4162   int partial_result_index = reduction_info.IsRowReduction() ? 0 : x_iter_num;
4163 
4164   InlinedVector<llvm_ir::ElementGenerator, 1> input_gens;
4165   std::vector<std::pair<llvm_ir::ElementGenerator, int>> extra_output_gens;
4166 
4167   // Construct the ElementGenerator for each reduction and extra output in the
4168   // the group of output instructions.
4169   for (int index : instr_index_group) {
4170     mlir::Operation* inst = GetReduceFromUnnestedMlir(unnested_hlo, index);
4171 
4172     const HloInstruction* hlo = fused_computation->root_instruction();
4173     if (hlo->opcode() == HloOpcode::kTuple) {
4174       hlo = hlo->operand(index);
4175     }
4176     if (IsReductionFromOrToContiguousDimensions(inst, layout_analysis)) {
4177       input_gens.push_back(*fused_emitter->GetGenerator(hlo->operand(0)));
4178     } else {
4179       extra_output_gens.emplace_back(*fused_emitter->GetGenerator(hlo), index);
4180     }
4181   }
4182 
4183   IrArray::Index input_index =
4184       GetUnnormalizedIndex(index, reduction_operand_shape, &b_,
4185                            reduction_info.GetKernelMappingScheme());
4186   // Clear the linear index field of the IrArray::Index to enable the use of
4187   // GetElementPointer with array types. This enables the vectorization of
4188   // the computation for different partial results. Use this index if
4189   // 'num_partial_results > 1'.
4190   int num_partial_results = reduction_info.GetNumPartialResults();
4191   auto index_without_linear = IrArray::Index(
4192       input_index.multidim(), reduction_operand_shape, input_index.GetType());
4193 
4194   // Emit code to generate the input and perform the reduction computation for
4195   // each reduction instruction.
4196   for (int i = 0; i < reducers.size(); i++) {
4197     llvm::AllocaInst* input_address =
4198         reduction_info.GetReductionInputAddresses()[i];
4199     llvm::AllocaInst* partial_reduction_result_address =
4200         reduction_info.GetPartialResultAddresses()[i];
4201     llvm::Value* const input_ir_value =
4202         input_gens[i](num_partial_results > 1 ? index_without_linear
4203                                               : input_index)
4204             .ValueOrDie();
4205     Store(input_ir_value, input_address);
4206     llvm::Value* partial_result_address = InBoundsGEP(
4207         partial_reduction_result_address, {b_.getInt32(partial_result_index)});
4208     TF_CHECK_OK(EmitCallToNestedComputation(
4209         *reducers[i], {partial_result_address, input_address},
4210         partial_result_address));
4211   }
4212 
4213   // Emit code to generate the output for the non-reduction instructions in the
4214   // fusion, if any.
4215   TF_CHECK_OK(EmitExtraOutputsForReduce(
4216       result_ir_arrays, input_index,
4217       /*use_linear_index=*/num_partial_results == 1, extra_output_gens));
4218 }
4219 
4220 llvm::Value* IrEmitterUnnested::EmitThreadId(int64_t threads_per_block,
4221                                              llvm::Type* index_ty) {
4222   // Calculate (y, x) coordinates respectively in the 2D view of thread block,
4223   // defined by (num_thread_y, num_thread_x) from thread_id.
4224   llvm::CallInst* thread_id_raw = gpu::EmitCallToTargetIntrinsic(
4225       gpu::TargetIntrinsicID::kThreadIdx, {}, {}, &b_);
4226   llvm_ir::AddRangeMetadata(0, threads_per_block, thread_id_raw);
4227   return b_.CreateIntCast(thread_id_raw, index_ty,
4228                           /*isSigned=*/true, "thread.id.x");
4229 }
4230 
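// Illustrative example (not from the original source): for a 256-thread block
// with num_threads_x = 32, thread_id 70 yields thread_id_x = 6 (70 % 32),
// thread_id_y = 2 (70 / 32), and lane_id = 6 (70 % kWarpSize).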
4231 IrEmitterUnnested::ThreadIdInfo IrEmitterUnnested::EmitThreadIdInfo(
4232     int64_t threads_per_block, llvm::Type* index_ty, int64_t num_threads_x) {
4233   auto constant = [&](uint64 c) -> llvm::Constant* {
4234     return llvm::ConstantInt::get(index_ty, c);
4235   };
4236   llvm::Value* thread_id = EmitThreadId(threads_per_block, index_ty);
4237   llvm::Value* num_threads_x_v = constant(num_threads_x);
4238   return {
4239       /*thread_id=*/thread_id,
4240       /*thread_id_x=*/b_.CreateURem(thread_id, num_threads_x_v, "thread_id.x"),
4241       /*thread_id_y=*/b_.CreateUDiv(thread_id, num_threads_x_v, "thread_id.y"),
4242       /*lane_id=*/b_.CreateURem(thread_id, constant(kWarpSize), "lane_id")};
4243 }
4244 
4245 IrEmitterUnnested::TilingKernelInfo IrEmitterUnnested::EmitTilingKernel(
4246     const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty,
4247     const TileElementGenerator& tile_element_generator) {
4248   absl::Span<const int64> dims_in_elems = mapping_scheme.GetDimsInElems();
4249   std::vector<int64> dims_in_blocks = {
4250       CeilOfRatio(dims_in_elems[0], mapping_scheme.GetTileSizeZ()),
4251       CeilOfRatio(dims_in_elems[1], mapping_scheme.GetTileSizeY()),
4252       CeilOfRatio(dims_in_elems[2], mapping_scheme.GetTileSizeX())};
4253   auto constant = [&](uint64 c) -> llvm::Constant* {
4254     return llvm::ConstantInt::get(index_ty, c);
4255   };
4256 
4257   IrEmitterUnnested::ThreadIdInfo thread_id_info =
4258       EmitThreadIdInfo(mapping_scheme.GetThreadsPerBlock(), index_ty,
4259                        mapping_scheme.GetNumThreadsX());
4260 
4261   KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
4262 
4263   const IrArray::Index block_coords = [&] {
4264     llvm::Value* block_id = EmitBlockId();
4265     llvm_ir::AddRangeMetadata(0, mapping_scheme.GetNumberOfBlocks(),
4266                               llvm::cast<llvm::Instruction>(block_id));
4267     llvm::Value* linear_block_id =
4268         b_.CreateIntCast(block_id, index_ty, /*isSigned=*/true, "block.id.x");
4269     IrArray::Index starting_block(linear_block_id,
4270                                   ShapeUtil::MakeShapeWithDescendingLayout(
4271                                       PRED /*arbitrary*/, dims_in_blocks),
4272                                   &b_);
4273 
4274     std::vector<llvm::Value*> multidim = {
4275         b_.CreateMul(starting_block[0], constant(mapping_scheme.GetTileSizeZ()),
4276                      "block_origin.z"),
4277         starting_block[1], starting_block[2]};
4278     return IrArray::Index(multidim, dims_in_blocks, index_ty);
4279   }();
4280 
4281   std::array<llvm::Value*, 3> output_tile_bounds;
4282   for (int i = kDimY; i < kDimTot; ++i) {
4283     int64_t tile_size_for_dim = mapping_scheme.GetTileSizeFor(i);
4284     // Only the last row or column may not have full size.
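    // For example (illustrative): a 100-element dimension tiled by 32 gives 4
    // blocks, and the last block's tile bound is 100 - 3 * 32 = 4.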
4285     llvm::Value* is_last =
4286         b_.CreateICmpEQ(block_coords[i], constant(dims_in_blocks[i] - 1));
4287     int64_t partial_row =
4288         dims_in_elems[i] - (dims_in_blocks[i] - 1) * tile_size_for_dim;
4289     output_tile_bounds[i] =
4290         b_.CreateSelect(is_last, constant(partial_row),
4291                         constant(tile_size_for_dim), "tile_bound");
4292   }
4293 
4294   IrArray::Index tile_origin = [&] {
4295     std::vector<llvm::Value*> elem_multi_index = block_coords.multidim();
4296     llvm::Type* index_ty = block_coords.GetType();
4297     for (int i = kDimY; i < kDimTot; ++i) {
4298       elem_multi_index[i] = b_.CreateMul(
4299           block_coords[i],
4300           llvm::ConstantInt::get(index_ty, mapping_scheme.GetTileSizeFor(i)),
4301           "tile_origin." + std::to_string(i));
4302     }
4303     return IrArray::Index(elem_multi_index, mapping_scheme.GetDimsInElems(),
4304                           index_ty);
4305   }();
4306 
4307   auto emit_tile = [&](const IrArray::Index& tile) {
4308     tile_element_generator(thread_id_info, tile, "output",
4309                            output_tile_bounds[1], output_tile_bounds[2], &ksl);
4310   };
4311 
4312   if (mapping_scheme.GetTileSizeZ() == 1) {
4313     emit_tile(tile_origin);
4314   } else {
4315     llvm::Value* starting_tile_index_for_dim = tile_origin[kDimZ];
4316     llvm::Value* block_size_for_dim = constant(mapping_scheme.GetTileSizeZ());
4317     llvm::Value* block_id_for_dim =
4318         b_.CreateUDiv(starting_tile_index_for_dim, block_size_for_dim);
4319     llvm::Value* last_block_for_dim = constant(dims_in_blocks[kDimZ] - 1);
4320     llvm::Value* last_block_size_for_dim =
4321         constant(dims_in_elems[kDimZ] -
4322                  (dims_in_blocks[kDimZ] - 1) * mapping_scheme.GetTileSizeZ());
4323 
4324     llvm::Value* num_tiles_in_block =
4325         b_.CreateSelect(b_.CreateICmpEQ(last_block_for_dim, block_id_for_dim),
4326                         last_block_size_for_dim, block_size_for_dim);
4327     ksl.For("loop_z",
4328             /*start=*/constant(0),
4329             /*end=*/num_tiles_in_block,
4330             /*step=*/1, [&](llvm::Value* block_dim_induction_var) {
4331               IrArray::Index tile_index = tile_origin.AddOffsetToDim(
4332                   block_dim_induction_var, kDimZ, &b_);
4333               emit_tile(tile_index);
4334             });
4335   }
4336   return {output_tile_bounds, tile_origin};
4337 }
4338 
4339 llvm::CallInst* IrEmitterUnnested::EmitSyncThreads() {
4340   return EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_);
4341 }
4342 
4343 // Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose
4344 // algorithm to improve the memory access patterns for the input parameters
4345 // with a shape that is a 0-2-1 transpose of the output tensor shape. The caller
4346 // is responsible for making sure that it is safe to apply the shared memory
4347 // transpose on the input parameters.
4348 //
4349 //
4350 // For the purpose of tiling, the output tensors have a logical shape of three
4351 // components 0-2-1 while the relevant input parameters have a logical shape
4352 // of three components 0-1-2 in the order major to minor. The x- and y-
4353 // dimensions of the tensors are tiled in square tiles with an edge length
4354 // `kTileSize`. Each thread block of `kTileSize` x `kNumRows` threads
4355 // transposes one tile: each thread copies kTileSize/kNumRows elements from
4356 // the input to a shared memory tile, then the otherwise "regular HLO kernel"
4357 // reads from the shared memory instead of the original input.
4358 //
4359 // This is similar to the following CUDA algorithm in TensorFlow:
4360 // https://goo.gl/MStRV6.
4361 //
4362 // `kTileSize` should usually be the same as the warp size. We currently choose 32 for
4363 // `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`.
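//
// Concretely (illustrative, using the constants above): with kTileSize = 32 and
// kNumRows = 4, each block has 32 x 4 = 128 threads, handles one 32 x 32 tile,
// and each thread copies 32 / 4 = 8 elements of that tile.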
4364 //
4365 // TODO(b/33320379): Here each block transposes 1 tile. It may be more
4366 // efficient to launch fewer blocks so each transposes many tiles.
4367 void IrEmitterUnnested::EmitHlo021Tile(
4368     mlir::Operation* op, Thunk* kernel_thunk, const MlirEmitterContext& context,
4369     absl::Span<const llvm_ir::IrArray> operand_arrays,
4370     absl::Span<const llvm_ir::IrArray> output_arrays,
4371     absl::Span<const int64> reduced_output_dims,
4372     absl::Span<const int64> tiled_param_ids,
4373     const KernelMappingScheme& mapping_scheme,
4374     const LaunchDimensions& launch_dimensions) {
4375   std::string name = mlir::GetNameFromLoc(op->getLoc());
4376 
4377   llvm::Type* index_type =
4378       GetIndexTypeForKernel(op, launch_dimensions.launch_bound(), &b_);
4379   std::vector<IrArray> param_arrays;
4380 
4381   // For each tiled parameter, cast its input IrArray to the corresponding
4382   // reduced shape and keep the reduced shape live during IR emission.
4383   std::vector<IrArray> param_in_reduced_shape_arrays;
4384   std::vector<llvm::Value*> param_shmem_buffers(context.operand_shapes.size(),
4385                                                 nullptr);
4386 
4387   auto get_shared_memory_buffer = [&](llvm::Type* elem_ty,
4388                                       absl::string_view buffer_name) {
4389     // For Nvidia GPUs, the warp size is 32 threads and shared memory is
4390     // organized into 32 banks. We usually use the warp size, or a multiple of
4391     // the warp size, as the tile size. This may cause all elements in
4392     // the same column of a tile to use the same memory bank and therefore
4393     // cause shared memory bank conflicts. Adding 1 to the minor dimension of
4394     // the shared memory buffer can reduce such shared memory bank conflicts.
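    // For example (illustrative): a 32 x 32 f32 tile is padded to 32 x 33 so
    // that elements of the same column in consecutive rows fall into different
    // banks.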
4395     llvm::Type* buffer_type = llvm::ArrayType::get(
4396         llvm::ArrayType::get(elem_ty, mapping_scheme.GetTileSizeX() + 1),
4397         mapping_scheme.GetTileSizeY());
4398     return llvm_ir::AllocateSharedMemoryTile(b_.GetInsertBlock()->getModule(),
4399                                              buffer_type, buffer_name);
4400   };
4401 
4402   for (int64_t id = 0; id < context.operand_shapes.size(); id++) {
4403     const Shape& param_shape = context.operand_shapes[id];
4404     param_arrays.push_back(operand_arrays[id]);
4405 
4406     if (absl::c_linear_search(tiled_param_ids, id)) {
4407       param_shmem_buffers[id] = get_shared_memory_buffer(
4408           llvm_ir::PrimitiveTypeToIrType(param_shape.element_type(), module_),
4409           IrName(name, StrCat("tile", id)));
4410       VLOG(3) << "Added shmem buffer for parameter " << id << ": "
4411               << llvm_ir::DumpToString(*param_shmem_buffers[id]);
4412       Shape reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
4413           param_shape.element_type(), Permute(reduced_output_dims, {0, 2, 1}));
4414       param_in_reduced_shape_arrays.push_back(
4415           param_arrays[id].CastToShape(reduced_shape, &b_));
4416     } else {
4417       param_in_reduced_shape_arrays.push_back(IrArray());
4418     }
4419   }
4420 
4421   EmitElementFunction element_generator =
4422       [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
4423           llvm::Value* x_loc, int64_t x_iter_num) {
4424         auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op);
4425         EmitTileElementForFusion(fusion, operand_arrays, output_arrays, index,
4426                                  mapping_scheme, y_loc, x_loc,
4427                                  param_shmem_buffers);
4428       };
4429 
4430   TileElementGenerator tile_generator =
4431       [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
4432           const string& loop_name, llvm::Value* tile_height,
4433           llvm::Value* tile_width, KernelSupportLibrary* ksl) {
4434         // If shared memory transpose is needed, wait for all threads to reach
4435         // this point, lest we copy a value from tile to output before the other
4436         // thread copies it from input to tile. This is `__syncthreads` in CUDA.
4437         if (!tiled_param_ids.empty()) {
4438           // Calculate the input tile origin from the output tile origin.
4439           const IrArray::Index input_tile_origin(
4440               Permute(index.multidim(), {0, 2, 1}),
4441               Permute(index.dims(), {0, 2, 1}), index.GetType());
4442 
4443           // Copy input parameter values to shared memory buffers:
4444           // tile[thread_id_y, thread_id_x] = input[index]
4445           // Note that tile_width and tile_height are flipped here because we
4446           // are reading a transposed tile.
4447           EmitTile(mapping_scheme, input_tile_origin, "input", ksl,
4448                    thread_id_info, tile_width, tile_height,
4449                    [&](const IrArray::Index& index, llvm::Value* y_loc,
4450                        llvm::Value* x_loc, int64_t /*x_iter_num*/) {
4451                      for (int64_t id : tiled_param_ids) {
4452                        IrArray& input_in_logical_shape =
4453                            param_in_reduced_shape_arrays.at(id);
4454 
4455                        llvm::Value* shmem_buffer = param_shmem_buffers.at(id);
4456                        llvm::Value* zero =
4457                            llvm::ConstantInt::get(index_type, 0);
4458                        // TODO(jlebar): Add AA metadata to this store.  Tile
4459                        // buffers are global variables, so LLVM can't infer much
4460                        // about it.
4461                        auto value = input_in_logical_shape.EmitReadArrayElement(
4462                            index, &b_, "input_element");
4463                        auto addr = GEP(shmem_buffer, {zero, y_loc, x_loc});
4464                        Store(value, addr);
4465                      }
4466                    });
4467 
4468           // Wait for all threads to reach this point using `__syncthreads` in
4469           // CUDA.
4470           EmitSyncThreads();
4471         }
4472 
4473         EmitTile(mapping_scheme, index, loop_name, ksl, thread_id_info,
4474                  tile_height, tile_width, element_generator);
4475         bool block_contains_multi_tiles = mapping_scheme.GetTileSizeZ() > 1;
4476 
4477         // If a tile block contains multiple tiles and shared memory buffers are
4478         // used, we need to wait for all threads to finish using the shared
4479         // memory buffer for the current tile before we move on to process the
4480         // next tile and overwrite the shared memory buffers.
4481         if (block_contains_multi_tiles && !tiled_param_ids.empty()) {
4482           EmitSyncThreads();
4483         }
4484       };
4485 
4486   EmitTilingKernel(mapping_scheme, index_type, tile_generator);
4487 }
4488 
4489 namespace {
4490 
4491 // A recursive function to inspect the users of a parameter to determine
4492 // whether it's safe for a parameter to participate in a shared-memory
4493 // transpose.
4494 //
4495 // Consider a fusion parameter P for which we might want to use a shmem
4496 // transpose.  If we do, we use a GPU thread block to preload a tile of P with
4497 // indices [z, y..y+31, x..x+31] to compute an output tile with the same indices
4498 // cooperatively, where z, y, x are the indices for the normalized input/output
4499 // tensor (see the document for FindTranspose021 for the definition of
4500 // normalized tensor for 0-2-1 transpose). This shmem transpose implementation
4501 // requires that the computation of the output tile only read elements within
4502 // the preload tile. If this is not true, we can't use a shmem transpose for P.
4503 //
4504 // If the computation of output element [z, y, x] only requires the element of
4505 // P with the same indices, the shmem transpose implementation can be applied
4506 // to P safely. This is a sufficient but not necessary condition. We check all
4507 // the transitive users of P to see if we can find a user that may cause an
4508 // exception to the situation. If such a user is not found, we conclude that P
4509 // is safe for shmem transpose.
4510 //
4511 // This is trivially true for elementwise operations and some "data-movement"
4512 // ops like kTuple. However, it's not true for operations that can change the
4513 // dimensions of the inputs (e.g. pad, slice) and for the bitcast operation.
4514 // For example:
4515 //
4516 // fused_computation {
4517 //   param_0 = f32[64,64]{1,0} parameter(0)
4518 //   ROOT bitcast = f32[64,64]{0,1} bitcast(param_0)
4519 // }
4520 // The output element at logical address [0, 63] depends on the input element
4521 // at logical address [63, 0], which would not be within the shared-memory
4522 // block.
4523 //
4524 // TODO(bixia): In order to extend this for kInput fusion, that is, reduction
4525 // with transpose, we only need to end the use-chain checking with the input of
4526 // a reduce operation. In this case, the above description of "output" applies
4527 // to the result of such a use-chain, which provides the input to the reduce
4528 // operation.
4529 bool IsInstructionSafeForShmemTranspose(mlir::Operation* op) {
4530   if (mlir::isa<mlir::memref::TensorStoreOp>(op)) {
4531     return true;
4532   }
4533 
4534   HloOpcode opcode;
4535   if (mlir::isa<mlir::memref::TensorLoadOp>(op)) {
4536     opcode = HloOpcode::kParameter;
4537   } else {
4538     opcode = *MhloToHloOpcode(op);
4539   }
4540   if (HloInstruction::IsOpElementwise(opcode)) {
4541     for (mlir::Value v : op->getResults()) {
4542       for (mlir::OpOperand use : v.getUsers()) {
4543         if (!IsInstructionSafeForShmemTranspose(use.getOwner())) {
4544           return false;
4545         }
4546       }
4547     }
4548     return true;
4549   }
4550 
4551   switch (opcode) {
4552     // Non-elementwise instructions that don't cause the shmem transpose
4553     // to be unsafe, including instructions that don't currently fuse.
4554     case HloOpcode::kGetDimensionSize:
4555       // The result of the operation doesn't rely on the content of the
4556       // tensor. As such, there is no need to further inspect its users.
4557       return true;
4558     case HloOpcode::kGetTupleElement:
4559     case HloOpcode::kMap:
4560     case HloOpcode::kParameter:
4561     case HloOpcode::kTuple:
4562     case HloOpcode::kTupleSelect:
4563       for (mlir::Value v : op->getResults()) {
4564         for (mlir::OpOperand use : v.getUsers()) {
4565           if (!IsInstructionSafeForShmemTranspose(use.getOwner())) {
4566             return false;
4567           }
4568         }
4569       }
4570       return true;
4571 
4572     default:
4573       return false;
4574   }
4575 }
4576 
4577 // Given a group of input parameters that are 0-2-1 transpose of the outputs of
4578 // a fusion kernel, returns the input parameters that are safe for the shared
4579 // memory transpose implementation.
4580 //
4581 // When a tile-based shared memory transpose is used to implement an input with
4582 // a 0-2-1 transpose, we preload a tile of the input elements
4583 // [z, y..y+31, x..x+31] to compute the output tile elements of the same
4584 // indices. Preloading the input tile this way is only safe when the computation
4585 // of the output tile elements does not need any input element outside the
4586 // preloaded tile. We inspect all the transitive users of the input parameter
4587 // up to the fusion root instruction to see if we can find any instruction
4588 // that can make preloading the input tile unsafe.
4589 std::vector<int64> FilterInputsForShmemTranspose(mlir::lmhlo::FusionOp fusion,
4590                                                  std::vector<int64> input_ids) {
4591   std::vector<mlir::Value> params = ToStdVector(fusion.getFusionParameters());
4592 
4593   std::vector<int64> filtered_input_ids;
4594   for (int64_t input_id : input_ids) {
4595     mlir::Value input = params.at(input_id);
4596     if (IsInstructionSafeForShmemTranspose(input.getDefiningOp())) {
4597       filtered_input_ids.push_back(input_id);
4598     }
4599   }
4600   return filtered_input_ids;
4601 }
4602 
4603 }  // namespace
4604 
4605 StatusOr<bool> IrEmitterUnnested::CheckAndEmitHloWithTile021(
4606     mlir::Operation* op) {
4607   CHECK(mlir::isa<mlir::lmhlo::FusionOp>(op));
4608 
4609   MlirEmitterContext context;
4610   context.SetOperation(op);
4611 
4612   // If the output_shape is reduced to 021 shape, find all the parameters of
4613   // the HLO that are in the corresponding 012 shape.
4614   std::vector<int64> params_012;
4615   optional<std::vector<int64>> reduced_dims_021;
4616   for (int64_t operand_idx = 0; operand_idx < context.operand_shapes.size();
4617        ++operand_idx) {
4618     const Shape& operand_shape = context.operand_shapes[operand_idx];
4619     auto find_transpose_result =
4620         ShapeUtil::FindTranspose021(operand_shape, context.output_shapes[0]);
4621     if (!find_transpose_result.has_value()) {
4622       continue;
4623     }
4624     const std::vector<int64>& curr_reduced_dims_021 = *find_transpose_result;
4625     if (!reduced_dims_021.has_value()) {
4626       reduced_dims_021 = curr_reduced_dims_021;
4627     }
4628     if (!absl::c_equal(*reduced_dims_021, curr_reduced_dims_021)) {
4629       // There is more than one possible transpose. Instead of picking one
4630       // transpose, we simply give up here.
4631       return false;
4632     }
4633     params_012.push_back(operand_idx);
4634   }
4635 
4636   if (!reduced_dims_021.has_value()) {
4637     return false;
4638   }
4639 
4640   if ((*reduced_dims_021)[1] < kMinDimensionToTransposeTiled ||
4641       (*reduced_dims_021)[2] < kMinDimensionToTransposeTiled) {
4642     return false;
4643   }
4644 
4645   if (auto fusion_op = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
4646     params_012 = FilterInputsForShmemTranspose(fusion_op, params_012);
4647     if (params_012.empty()) {
4648       return false;
4649     }
4650   }
4651 
4652   // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the
4653   // elements are of size 4 bytes), and CUDA has an architectural limit of
4654   // 48kb shared memory per SM.  (This is increased to 96kb in Volta, but we
4655   // don't use this, in part because it eats into our L1 cache space.)
4656   //
4657   // For correctness we need to ensure that we don't make more than 48kb worth
4658   // of shmem tiles per block.  And for performance, we'd probably like to use
4659   // significantly less, so that we can fit more than one block at a time on a
4660   // gpu core.
4661   //
4662   // We say without benchmarks that we want at least 3 blocks per core, which
4663   // corresponds to at most 3 shmem tiles per block if the elements are 32 bits
4664   // wide.  We choose which params get the shmem transpose treatment
4665   // arbitrarily; it's not clear if there's a Right Choice.
4666   //
4667   // This is only sound if tiled transposes are the only place where we use
4668   // shared memory in fusions.  If in the future other fusible ops use shared
4669   // memory, we'll have to adjust this heuristic.
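  // Worked example for f32 inputs: one tile is 32 * 33 * 4 B = 4224 B, and with
  // kMinBlocksPerCore = 3 the per-block budget is 48 KiB / 3 = 16384 B, so the
  // loop below keeps at most 3 f32 tiles (3 * 4224 = 12672 B fits, while a 4th
  // tile would need 16896 B and gets dropped).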
4670   constexpr int kMinBlocksPerCore = 3;
4671   constexpr int64_t kShmemPerCore = 48 * 1024;
4672   int64_t shmem_used = 0;
4673   for (int64_t i = 0; i < params_012.size(); ++i) {
4674     const Shape& operand_shape = context.operand_shapes[params_012[i]];
4675     shmem_used +=
4676         32 * 33 *
4677         ShapeUtil::ByteSizeOfPrimitiveType(operand_shape.element_type());
4678 
4679     if (kMinBlocksPerCore * shmem_used > kShmemPerCore) {
4680       // Erase this element and everything after it from params_012.
4681       params_012.resize(i);
4682       break;
4683     }
4684   }
4685 
4686   if (params_012.empty()) {
4687     return false;
4688   }
4689 
4690   constexpr int kNumRows = 4;
4691   KernelMappingScheme mapping_scheme(*reduced_dims_021,
4692                                      /*tile_sizes=*/{1, kWarpSize, kWarpSize},
4693                                      /*num_threads_y=*/kNumRows,
4694                                      /*num_threads_x=*/kWarpSize,
4695                                      /*indexing_order=*/kLinearIndexingX,
4696                                      /*vector_size=*/1,
4697                                      /*is_row_contiguous=*/false);
4698   LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
4699                                      mapping_scheme.GetThreadsPerBlock());
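  // With these parameters each block has kNumRows * kWarpSize = 4 * 32 = 128
  // threads and covers one 32x32 tile, so each thread is responsible for
  // roughly 32 / kNumRows = 8 rows of its column of the tile (the exact
  // striding is decided by EmitTile).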
4700   std::vector<llvm_ir::IrArray> ir_arrays;
4701   TF_ASSIGN_OR_RETURN(
4702       std::unique_ptr<KernelThunk> kernel_thunk,
4703       BuildKernelThunk(op, GetThunkInfo(op), &ir_arrays, launch_dimensions));
4704 
4705   EmitHlo021Tile(
4706       op, kernel_thunk.get(), context,
4707       absl::MakeSpan(ir_arrays).subspan(0, context.operand_shapes.size()),
4708       absl::MakeSpan(ir_arrays).subspan(context.operand_shapes.size()),
4709       *reduced_dims_021, params_012, mapping_scheme, launch_dimensions);
4710   AddThunkToThunkSequence(std::move(kernel_thunk));
4711   return true;
4712 }
4713 
4714 namespace {
4715 
4716 // Returns true if all the transitive users of `value`, before reaching the
4717 // users in `use_chain_endings`, are elementwise operations.
4718 bool AreUsersElementwise(
4719     mlir::Value value,
4720     const absl::flat_hash_set<mlir::Operation*>& use_chain_endings) {
4721   return absl::c_all_of(value.getUsers(), [&](mlir::OpOperand use) {
4722     mlir::Operation* user = use.getOwner();
4723     CHECK_EQ(1, user->getNumResults());
4724     return use_chain_endings.count(user) ||
4725            (HloInstruction::IsOpElementwise(*MhloToHloOpcode(user)) &&
4726             AreUsersElementwise(user->getResult(0), use_chain_endings));
4727   });
4728 }
4729 
4730 // Returns the number of fusion inputs that have the same dimension as the
4731 // given shape, and are involved in only elementwise operations.
4732 int64 NumInputsInvolveInOnlyElementwiseOps(
4733     mlir::lmhlo::FusionOp fusion, const Shape& op_shape,
4734     const absl::flat_hash_set<mlir::Operation*>& use_chain_endings) {
4735   return absl::c_count_if(
4736       fusion.getFusionParameters(), [&](mlir::Value parameter) {
4737         Shape parameter_shape = GetShape(parameter);
4738         return ShapeUtil::SameDimensions(op_shape, parameter_shape) &&
4739                AreUsersElementwise(parameter, use_chain_endings);
4740       });
4741 }
4742 
4743 // Returns the number of fusion inputs that have more elements than the given
4744 // shape.
4745 int64 NumInputsWithMoreElementsThan(mlir::lmhlo::FusionOp fusion,
4746                                     const Shape& shape) {
4747   int64_t num_elements = ShapeUtil::ElementsIn(shape);
4748   return absl::c_count_if(
4749       fusion.getFusionParameters(), [&](mlir::Value parameter) {
4750         Shape parameter_shape = GetShape(parameter);
4751         return ShapeUtil::ElementsIn(parameter_shape) > num_elements;
4752       });
4753 }
4754 
4755 // The benefit of unrolling a kInput fusion that is a column reduction comes
4756 // from the vectorization of non-reduction fusion outputs and fusion inputs.
4757 // On the other hand, unrolling can also introduce factors that can cause
4758 // the kernel to run slower. This routine uses a simple heuristic to estimate
4759 // the benefit as well as the overhead of unrolling in order to decide whether
4760 // unrolling is beneficial for the given kInput fusion.
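// Hypothetical example: a fusion with one column-reduce output (its atomic
// update cannot be vectorized), one elementwise output, and two inputs of the
// reduce-input shape that feed only elementwise ops counts
// can_be_vectorized = 3 vs. cannot_be_vectorized = 1, so unrolling is
// considered beneficial; adding two larger (e.g. padded) inputs shifts the
// balance to 3 vs. 3, which still passes the >= test below.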
4761 bool IsUnrollingColumnReductionBeneficial(
4762     mlir::Operation* unnested_hlo, const Shape& input_shape,
4763     int64_t num_kept_minor, const FusionLayoutAnalysis& layout_analysis) {
4764   // TODO(b/122468062): Need to investigate further whether we can
4765   // remove the constraint on IsPowerOfTwo.
4766   if (!IsPowerOfTwo(static_cast<uint64>(num_kept_minor))) {
4767     return false;
4768   }
4769 
4770   if (IsReductionFromOrToContiguousDimensions(unnested_hlo, layout_analysis)) {
4771     return true;
4772   }
4773 
4774   auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(unnested_hlo);
4775   int64_t can_be_vectorized = 0;
4776   int64_t cannot_be_vectorized = 0;
4777   auto fusion_results = ToStdVector(fusion.getFusionResults());
4778   absl::flat_hash_set<mlir::Operation*> use_chain_endings;
4779   if (fusion_results.size() == 1) {
4780     if (IsReductionFromOrToContiguousDimensions(
4781             fusion_results[0].getDefiningOp(), layout_analysis)) {
4782       use_chain_endings.insert(fusion_results[0].getDefiningOp());
4783       // Atomic.add of the reduction result can't be vectorized.
4784       cannot_be_vectorized++;
4785     }
4786   } else {
4787     for (mlir::Value result : fusion_results) {
4788       if (IsReductionFromOrToContiguousDimensions(result.getDefiningOp(),
4789                                                   layout_analysis)) {
4790         // Atomic.add of the reduction result can't be vectorized.
4791         cannot_be_vectorized++;
4792       } else {
4793         // Write of the non-reduction result can be vectorized.
4794         can_be_vectorized++;
4795       }
4796       use_chain_endings.insert(result.getDefiningOp());
4797     }
4798   }
4799   // Fusion inputs that have the same dimension as the reduce input and
4800   // are involved in only elementwise operations can be vectorized.
4801   can_be_vectorized += NumInputsInvolveInOnlyElementwiseOps(fusion, input_shape,
4802                                                             use_chain_endings);
4803   // Fusion inputs with more elements than the reduce op input must participate
4804   // in non-elementwise operations and we assume that they are not vectorizable
4805   // for the purpose of estimating the benefit of unrolling. If the kernel is
4806   // unrolled even with such an assumption, and the accesses to those inputs
4807   // turn out to be vectorizable, the compiler will still vectorize them.
4808   cannot_be_vectorized += NumInputsWithMoreElementsThan(fusion, input_shape);
4809   return can_be_vectorized >= cannot_be_vectorized;
4810 }
4811 
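// Rounds v to the nearest power of two, preferring the smaller one on ties.
// For example (assuming tensorflow::NextPowerOfTwo64 rounds up to the next
// power of two >= v): v = 7 gives upper = 8, lower = 4 and returns 8, while
// v = 6 ties (distance 2 both ways) and returns the lower power, 4.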
4812 int64 NearestPowerOfTwo(int64_t v) {
4813   if (v < 0) {
4814     return 0;
4815   }
4816   int64_t upper = tensorflow::NextPowerOfTwo64(v);
4817   int64_t lower = upper >> 1;
4818   return upper - v < v - lower ? upper : lower;
4819 }
4820 
4821 }  // namespace
4822 
4823 ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo(
4824     mlir::Operation* unnested_hlo, mlir::Operation* first_reduce,
4825     const FusionLayoutAnalysis& layout_analysis) {
4826   Shape input_shape = layout_analysis.GetShape(first_reduce->getOperand(0));
4827   ReductionDimensions reduction_dimensions =
4828       GetReductionKindAndContiguousComponents(first_reduce);
4829   VLOG(10) << "is_row_reduction " << reduction_dimensions.is_row_reduction
4830            << " " << reduction_dimensions.dimensions[0] << " "
4831            << reduction_dimensions.dimensions[1] << " "
4832            << reduction_dimensions.dimensions[2];
4833   auto get_dtype_bits = [](mlir::Value i) {
4834     // TODO(timshen): may not be efficient.
4835     return primitive_util::BitWidth(GetShape(i).element_type());
4836   };
4837 
4838   // For fusion with multiple inputs, use the smallest input dtype to
4839   // select the reduction_tiling.
4840   int smallest_input_dtype_bits = get_dtype_bits(first_reduce->getOperand(0));
4841 
4842   for (mlir::Value operand : GetHloOperands(unnested_hlo)) {
4843     smallest_input_dtype_bits =
4844         std::min(get_dtype_bits(operand), smallest_input_dtype_bits);
4845   }
4846   std::array<int64, 3> reduction_tiling =
4847       GetReductionTiling(reduction_dimensions, smallest_input_dtype_bits,
4848                          ir_emitter_context_->cuda_compute_capability());
4849 
4850   int64_t num_threads_y = reduction_dimensions.is_row_reduction ? 1 : kWarpSize;
4851   int64_t num_threads_x = [&] {
4852     if (reduction_dimensions.is_row_reduction) {
4853       // Use 512 as default block size (threads per block) for row reductions.
4854       // For multi-output fusions, reduce the block size further to decrease
4855       // register pressure when multiple outputs are computed by each thread.
4856       int64_t fan_out = 1;
4857       if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo)) {
4858         fan_out = fusion.getFusionResults().size();
4859       }
4860 
4861       int64_t max_block_size =
4862           std::max(kMinThreadsXRowReduction,
4863                    static_cast<int64_t>(512LL / NearestPowerOfTwo(fan_out)));
4864       return std::min(
4865           max_block_size,
4866           RoundUpToNearest(CeilOfRatio(reduction_dimensions.dimensions[2],
4867                                        reduction_tiling[2]),
4868                            kWarpSize));
4869     }
4870     return kWarpSize;
4871   }();
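  // Hypothetical example: a single-output row reduction (fan_out = 1) gets
  // max_block_size = max(kMinThreadsXRowReduction, 512); if the tiled minor
  // dimension needs at least that many threads the block is capped there,
  // otherwise num_threads_x shrinks to the required thread count rounded up to
  // a multiple of kWarpSize.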
4872 
4873   bool tile_fit = reduction_dimensions.dimensions[kDimX] %
4874                       (reduction_tiling[2] * num_threads_x) ==
4875                   0;
4876   se::CudaComputeCapability cc = ir_emitter_context_->cuda_compute_capability();
4877 
4878   int num_partial_results = 1;
4879   KernelMappingScheme::IndexingOrder indexing_order = [&]() {
4880     if (reduction_dimensions.is_row_reduction &&
4881         // On P100, only try to vectorize+coalesce memory access when the
4882         // tile size fits exactly and the dtypes are <= 32 bits.
4883         ((cc.major == 6 && smallest_input_dtype_bits <= 32 && tile_fit) ||
4884          // On V100, only try to vectorize+coalesce memory access for
4885          // rows of even size.  For odd row sizes, every other row
4886          // isn't aligned, so it can't be vectorized.
4887          (cc.major >= 7 && reduction_dimensions.dimensions[2] % 2 == 0))) {
4888       return kStridedLinearIndexingX;
4889     } else if (!reduction_dimensions.is_row_reduction &&
4890                IsUnrollingColumnReductionBeneficial(
4891                    unnested_hlo, input_shape,
4892                    reduction_dimensions.dimensions[2], layout_analysis)) {
4893       num_partial_results = 2;
4894       reduction_tiling[2] *= num_partial_results;
4895       return kLinearIndexingX;
4896     } else {
4897       return kStridedIndexingX;
4898     }
4899   }();
4900 
4901   int vector_size = 1;
4902   if (indexing_order == kStridedLinearIndexingX) {
4903     // Assuming XLA will perform the unrolling and LLVM will vectorize,
4904     // disable the unroll for the cases that LLVM doesn't vectorize.
4905     if (reduction_dimensions.dimensions[2] % 2 == 0 &&
4906         !MayPreventVectorization(unnested_hlo)) {
4907       vector_size = 2;
4908     } else {
4909       indexing_order = kStridedIndexingX;
4910     }
4911   }
4912   KernelMappingScheme mapping_scheme(
4913       reduction_dimensions.dimensions,
4914       {reduction_tiling[0], reduction_tiling[1] * num_threads_y,
4915        reduction_tiling[2] * num_threads_x},
4916       num_threads_y, num_threads_x, indexing_order, vector_size);
4917   return ReductionCodegenInfo(
4918       mapping_scheme, num_partial_results,
4919       reduction_dimensions.is_row_reduction,
4920       ReductionIsRaceFree(reduction_dimensions, reduction_tiling));
4921 }
4922 
4923 void IrEmitterUnnested::EmitIRForReduction(
4924     mlir::Operation* unnested_hlo, absl::Span<const int> instr_index_group,
4925     HloComputation* fused_computation, FusedIrEmitter* fused_emitter,
4926     absl::Span<const llvm_ir::IrArray> result_ir_arrays,
4927     ReductionCodegenState* reduction_info, const Shape& input_shape,
4928     const FusionLayoutAnalysis& layout_analysis) {
4929   std::vector<HloComputation*> reducers;
4930   for (int index : instr_index_group) {
4931     mlir::Operation* reduce = GetReduceFromUnnestedMlir(unnested_hlo, index);
4932     if (!IsReductionFromOrToContiguousDimensions(reduce, layout_analysis)) {
4933       continue;
4934     }
4935     if (auto nested_reduce = mlir::dyn_cast<mlir::mhlo::ReduceOp>(reduce)) {
4936       HloInstruction* root = fused_computation->root_instruction();
4937       if (root->opcode() == HloOpcode::kTuple) {
4938         root = root->mutable_operand(index);
4939       } else {
4940         CHECK_EQ(0, index);
4941       }
4942       reducers.push_back(root->to_apply());
4943     } else {
4944       LOG(FATAL) << "Unexpected reduce op: " << MlirToString(reduce);
4945     }
4946   }
4947   CHECK(!reducers.empty()) << " expect at least one reduce instruction.";
4948 
4949   const KernelMappingScheme& mapping_scheme =
4950       reduction_info->GetKernelMappingScheme();
4951   LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
4952                                      mapping_scheme.GetThreadsPerBlock());
4953   llvm::Type* index_ty = GetIndexTypeForKernel(
4954       unnested_hlo, launch_dimensions.launch_bound(), &b_);
4955   EmitPrologueForReduction(unnested_hlo, instr_index_group, fused_computation,
4956                            fused_emitter, result_ir_arrays, reduction_info,
4957                            layout_analysis);
4958 
4959   EmitElementFunction emit_reduction_tile =
4960       [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
4961           llvm::Value* x_loc, int64_t x_iter_num) {
4962         EmitTileElementForReduction(
4963             unnested_hlo, input_shape, instr_index_group, fused_computation,
4964             fused_emitter, result_ir_arrays, reducers, index, *reduction_info,
4965             x_iter_num, layout_analysis);
4966       };
4967 
4968   TilingKernelInfo tiling_kernel_info = EmitTilingKernel(
4969       mapping_scheme, index_ty,
4970       [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
4971           const string& loop_name, llvm::Value* tile_height,
4972           llvm::Value* tile_width, KernelSupportLibrary* ksl) {
4973         EmitTile(reduction_info->GetKernelMappingScheme(), index, loop_name,
4974                  ksl, thread_id_info, tile_height, tile_width,
4975                  emit_reduction_tile);
4976       });
4977   EmitEpilogueForReduction(index_ty, unnested_hlo, instr_index_group,
4978                            result_ir_arrays, reducers, *reduction_info,
4979                            tiling_kernel_info, layout_analysis);
4980 }
4981 
4982 namespace {
4983 
4984 // Returns whether `instr` is either a constant, a scalar, or a
4985 // broadcasted constant/scalar.
4986 bool IsBroadcastedConstantOrScalar(const HloInstruction& instr) {
4987   return instr.IsConstant() || ShapeUtil::IsScalar(instr.shape()) ||
4988          (HloOpcode::kBroadcast == instr.opcode() &&
4989           (instr.operand(0)->IsConstant() ||
4990            ShapeUtil::IsScalar(instr.operand(0)->shape())));
4991 }
4992 
4993 // Divides `num_reduces` reduces into groups. Different groups will be executed
4994 // in parallel. Generally speaking, we'd like to run the reduce instructions
4995 // in parallel without incurring too much recomputation overhead. The current
4996 // heuristic is to place reduce instructions that share nothing or only
4997 // (broadcasted) scalars/constants into different groups; otherwise, they are
4998 // placed in the same group. Non-reduce instructions always go with the reduce
4999 // instructions into the same group so long as they share any predecessors.
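// Hypothetical example: for root = tuple(reduce0(mul(p0, p1)),
// reduce1(add(p0, p2)), reduce2(abs(broadcast(c)))), p0 reaches both reduce0
// and reduce1, so they are merged into one group, while reduce2 only shares
// the broadcasted constant and therefore ends up in its own group, to be run
// in parallel under a different block_id_y.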
5000 std::vector<std::vector<int>> GroupDisjointReductions(
5001     HloComputation* fused_computation, int num_reduces) {
5002   CHECK_NE(0, num_reduces);
5003   if (num_reduces == 1) {
5004     return {{0}};
5005   }
5006 
5007   std::vector<tensorflow::UnionFind<HloInstruction*>> disjoint_sets(
5008       num_reduces);
5009   for (size_t i = 0; i < num_reduces; ++i) {
5010     disjoint_sets[i].Get() =
5011         fused_computation->root_instruction()->mutable_operand(i);
5012   }
5013 
5014   std::unique_ptr<HloReachabilityMap> reachability_map =
5015       HloReachabilityMap::Build(fused_computation);
5016   for (HloInstruction* instr : fused_computation->instructions()) {
5017     std::vector<int64> reached_output_ids;
5018     for (size_t oid = 0; oid < num_reduces; ++oid) {
5019       HloInstruction* reduce =
5020           fused_computation->root_instruction()->mutable_operand(oid);
5021       if (HloOpcode::kReduce == reduce->opcode() &&
5022           (IsBroadcastedConstantOrScalar(*instr))) {
5023         // Do not group output reduce instructions through broadcasted
5024         // constants or scalars, as the recomputation should be acceptable.
5025         VLOG(3) << "Skip broadcasted constant or scalar " << instr->ToString();
5026         continue;
5027       }
5028       // Now group output instructions if they have common predecessors.
5029       if (reachability_map->IsReachable(instr, reduce)) {
5030         VLOG(3) << "Reaching " << reduce->ToString() << " from "
5031                 << instr->ToString();
5032         reached_output_ids.push_back(oid);
5033       }
5034     }
5035     for (size_t j = 1; j < reached_output_ids.size(); ++j) {
5036       disjoint_sets[reached_output_ids[0]].Merge(
5037           &disjoint_sets[reached_output_ids[j]]);
5038     }
5039   }
5040   // Place output instructions in the same disjoint set into the same group.
5041   HloInstructionMap<std::vector<int>> groups;
5042   for (size_t oid = 0; oid < num_reduces; ++oid) {
5043     groups[disjoint_sets[oid].Get()].push_back(oid);
5044   }
5045 
5046   std::vector<std::vector<int>> ret;
5047   absl::c_for_each(
5048       groups, [&](auto& iter) { ret.emplace_back(std::move(iter.second)); });
5049   return ret;
5050 }
5051 
5052 }  // namespace
5053 
5054 Status IrEmitterUnnested::EmitUnnestedReduction(
5055     mlir::Operation* unnested_hlo,
5056     const FusionLayoutAnalysis& layout_analysis) {
5057   auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(unnested_hlo);
5058 
5059   int num_reduces = fusion.getFusionResults().size();
5060 
5061   // Build a kernel thunk to compute all the outputs.
5062   mlir::Operation* first_reduce = nullptr;
5063   for (int i = 0; i < num_reduces; ++i) {
5064     mlir::Operation* output_instruction =
5065         GetReduceFromUnnestedMlir(unnested_hlo, i);
5066     if (IsReductionFromOrToContiguousDimensions(output_instruction,
5067                                                 layout_analysis)) {
5068       first_reduce = GetReduceFromUnnestedMlir(unnested_hlo, i);
5069       break;
5070     }
5071   }
5072   CHECK(first_reduce) << MlirToString(unnested_hlo);
5073   if (num_reduces > 1) {
5074     for (int i = 0; i < num_reduces; i++) {
5075       auto candidate = mlir::dyn_cast<mlir::mhlo::ReduceOp>(
5076           GetReduceFromUnnestedMlir(unnested_hlo, i));
5077       if (candidate &&
5078           !IsFusedReductionOutputConsistent(
5079               candidate, mlir::cast<mlir::mhlo::ReduceOp>(first_reduce),
5080               layout_analysis)) {
5081         return InternalError("Inconsistent reduction fusion outputs");
5082       }
5083     }
5084   }
5085   Shape input_shape = GetShape(first_reduce->getOperand(0));
5086   // The layout of a reduction input is either set by LayoutAssignment for
5087   // unnested kReduce or by InstructionFusion for fused kReduce.
5088   CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion "
5089                                      "doesn't set the input layout of "
5090                                   << MlirToString(first_reduce);
5091 
5092   HloComputation* fused_computation = nullptr;
5093   TF_ASSIGN_OR_RETURN(fused_computation,
5094                       GetOrCreateSubComputationFromRegion(&fusion.region(),
5095                                                           /*is_fusion=*/true));
5096 
5097   // Group disjoint reductions into groups, to be executed in parallel.
5098   std::vector<std::vector<int>> instr_index_groups =
5099       GroupDisjointReductions(fused_computation, num_reduces);
5100 
5101   VLOG(2) << StrCat("Generate in ", instr_index_groups.size(), " groups for ",
5102                     MlirToString(unnested_hlo));
5103 
5104   ReductionCodegenInfo reduction_info =
5105       ComputeReductionCodegenInfo(unnested_hlo, first_reduce, layout_analysis);
5106   const KernelMappingScheme& mapping_scheme =
5107       reduction_info.GetKernelMappingScheme();
5108 
5109   // block_y_count is set to instr_index_groups.size(), so that each reduction
5110   // group can be run in parallel by a different BlockIdy.
5111   LaunchDimensions launch_dimensions(
5112       {/*x=*/mapping_scheme.GetNumberOfBlocks(),
5113        /*y=*/static_cast<int64>(instr_index_groups.size()),
5114        /*z=*/1},
5115       {/*x=*/mapping_scheme.GetThreadsPerBlock(), /*y=*/1, /*z=*/1});
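  // For example, with 2 reduction groups and N blocks per group the grid is
  // (N, 2, 1), and the ksl.If below makes the body of group i execute only in
  // blocks whose block_id_y == i.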
5116   VLOG(3) << "Launch dimensions of "
5117           << mlir::GetNameFromLoc(unnested_hlo->getLoc())
5118           << ": number of blocks: " << mapping_scheme.GetNumberOfBlocks()
5119           << " - threads per block: " << mapping_scheme.GetThreadsPerBlock();
5120 
5121   std::vector<llvm_ir::IrArray> ir_arrays;
5122   TF_ASSIGN_OR_RETURN(std::unique_ptr<KernelThunk> kernel_thunk,
5123                       BuildKernelThunk(unnested_hlo, Thunk::ThunkInfo(),
5124                                        &ir_arrays, launch_dimensions));
5125 
5126   GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
5127                                           ir_emitter_context_->llvm_module(),
5128                                           &b_, GetNestedComputer());
5129   FusedIrEmitter fused_emitter(&elemental_emitter);
5130 
5131   CHECK_LT(fused_computation->num_parameters(), ir_arrays.size());
5132   for (int i = 0; i < fused_computation->num_parameters(); i++) {
5133     llvm_ir::IrArray ir_array = ir_arrays[i];
5134     HloInstruction* fused_operand = fused_computation->parameter_instruction(i);
5135     fused_emitter.BindGenerator(
5136         fused_operand,
5137         [this, ir_array, fused_operand](const llvm_ir::IrArray::Index& index) {
5138           return ir_array.EmitReadArrayElement(index, &b_,
5139                                                fused_operand->name());
5140         });
5141   }
5142   absl::Span<const llvm_ir::IrArray> result_ir_arrays =
5143       absl::MakeSpan(ir_arrays).subspan(fused_computation->num_parameters(),
5144                                         num_reduces);
5145 
5146   // We always use the first reduce as representative to construct
5147   // ReductionCodegenInfo, since all the reductions are required to have the
5148   // same shape and layout as verified by `IsFusedReductionOutputConsistent()`.
5149   ReductionCodegenInfo reduction_codegen_info =
5150       ComputeReductionCodegenInfo(unnested_hlo, first_reduce, layout_analysis);
5151 
5152   KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
5153   for (size_t i = 0; i < instr_index_groups.size(); ++i) {
5154     // Create a new ReductionCodegenState instance as it contains states for
5155     // code generation per reduction group.
5156     ReductionCodegenState reduction_info =
5157         ReductionCodegenState(reduction_codegen_info);
5158     // Use raw block_id_y to select the i-th parallel reduction to run. Using
5159     // block_id_y instead of block_id_x simplifies the index calculation
5160     // for reduction code generation as the block_id_y is orthogonal to
5161     // the indices used within the reductions.
5162     llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic(
5163         gpu::TargetIntrinsicID::kBlockIdy, {}, {}, &b_);
5164     llvm_ir::AddRangeMetadata(0, instr_index_groups.size(),
5165                               llvm::cast<llvm::Instruction>(raw_block_id_y));
5166     ksl.If(StrCat("reduce-group-", i),
5167            b_.CreateICmpEQ(raw_block_id_y, b_.getInt32(i)), [&] {
5168              EmitIRForReduction(unnested_hlo, instr_index_groups[i],
5169                                 fused_computation, &fused_emitter,
5170                                 result_ir_arrays, &reduction_info, input_shape,
5171                                 layout_analysis);
5172            });
5173   }
5174 
5175   if (hlo_module_config_.debug_options().xla_gpu_deterministic_reductions() &&
5176       !reduction_codegen_info.IsRaceFree()) {
5177     return InternalError(
5178         "All reductions should be race-free if deterministic reductions are "
5179         "enabled");
5180   }
5181 
5182   // Build an initializer thunk to initialize each reduction output.
5183   ThunkSequence thunks;
5184   for (int i = 0; i < num_reduces; ++i) {
5185     mlir::Operation* output_instruction =
5186         GetReduceFromUnnestedMlir(unnested_hlo, i);
5187     if (!IsReductionFromOrToContiguousDimensions(output_instruction,
5188                                                  layout_analysis)) {
5189       // Elemental IR emitter is used.
5190       continue;
5191     } else if (reduction_codegen_info.IsRaceFree()) {
5192       VLOG(5) << "We do not need initialization: using tree reductions";
5193       continue;
5194     }
5195 
5196     TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
5197                         BuildFusedInitializerThunk(fusion, i));
5198     thunks.push_back(std::move(initializer_thunk));
5199   }
5200 
5201   thunks.push_back(std::move(kernel_thunk));
5202   auto sequential_thunk = absl::make_unique<SequentialThunk>(
5203       GetThunkInfo(unnested_hlo), std::move(thunks));
5204   AddThunkToThunkSequence(std::move(sequential_thunk));
5205 
5206   return Status::OK();
5207 }
5208 
5209 // Emits code for slices based on the below structure. An if statement with
5210 // a guarding condition is generated for each ROOT slice.
5211 //
5212 // Pseudo code:
5213 //
5214 // Compute values of slice input operands
5215 //
5216 // Compute guarding_cond0
5217 // if (guarding_cond0) {
5218 //   Write to output of slice0
5219 // }
5220 //
5221 // Compute guarding_cond1
5222 // if (guarding_cond1) {
5223 //   Write to output of slice1
5224 // }
5225 //
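// For example (hypothetical shapes), a root slice s = f32[3] slice(f32[8] p),
// slice={[2:5]} is emitted as:
//
//   if (i >= 2 && i < 5) {
//     output_of_s[i - 2] = value_of_p[i];
//   }
//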
5226 Status IrEmitterUnnested::EmitElementForInputFusibleSlices(
5227     const HloComputation* fused_computation,
5228     absl::Span<const llvm_ir::IrArray> ir_arrays,
5229     const llvm_ir::IrArray::Index& index) {
5230   VLOG(10) << "Emitting slice input fusion for "
5231            << fused_computation->ToString();
5232 
5233   HloInstruction* slice_or_tuple = fused_computation->root_instruction();
5234   auto slice_instructions = [&]() -> absl::Span<HloInstruction* const> {
5235     if (slice_or_tuple->opcode() == HloOpcode::kSlice) {
5236       return absl::Span<HloInstruction* const>(&slice_or_tuple, 1);
5237     }
5238     CHECK_EQ(slice_or_tuple->opcode(), HloOpcode::kTuple);
5239     return slice_or_tuple->operands();
5240   }();
5241 
5242   // Emit input operand values of slices.
5243   std::vector<llvm::Value*> input_ir_values;
5244   GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
5245                                      GetNestedComputer());
5246   FusedIrEmitter fused_emitter(&elem_emitter);
5247   for (int i = 0; i < fused_computation->num_parameters(); i++) {
5248     fused_emitter.BindGenerator(
5249         fused_computation->parameter_instruction(i),
5250         [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
5251           return ir_arrays[i].EmitReadArrayElement(index, &b_);
5252         });
5253   }
5254   for (const HloInstruction* slice : slice_instructions) {
5255     auto input_generator = *fused_emitter.GetGenerator(slice->operand(0));
5256     input_ir_values.push_back(input_generator(index).ValueOrDie());
5257   }
5258 
5259   // Emit for slice_instructions.
5260   KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
5261   for (int64_t i = 0; i < slice_instructions.size(); ++i) {
5262     HloInstruction* slice = slice_instructions[i];
5263 
5264     // guarding_cond := index >= start && index < limit, for each dim.
5265     std::vector<llvm::Value*> index_within_ranges;
5266     for (size_t dim = 0; dim < slice->slice_starts().size(); ++dim) {
5267       CHECK_EQ(slice->slice_strides(dim), 1);
5268       auto larger_or_equal_than_start = b_.CreateICmpSGE(
5269           index.multidim()[dim],
5270           index.GetConstantWithIndexType(slice->slice_starts(dim)));
5271       llvm::Value* smaller_than_limit = b_.CreateICmpSLT(
5272           index.multidim()[dim],
5273           index.GetConstantWithIndexType(slice->slice_limits(dim)));
5274       llvm::Value* within_range =
5275           b_.CreateAnd(larger_or_equal_than_start, smaller_than_limit);
5276       index_within_ranges.push_back(within_range);
5277     }
5278     llvm::Value* guarding_cond = b_.CreateAnd(index_within_ranges);
5279 
5280     auto emit_slice_elem_func = [&] {
5281       const std::vector<llvm::Value*>& src_multidim = index.multidim();
5282       std::vector<llvm::Value*> dst_multidim(src_multidim.size());
5283       for (size_t dim = 0; dim < src_multidim.size(); ++dim) {
5284         dst_multidim[dim] =
5285             Sub(src_multidim[dim],
5286                 index.GetConstantWithIndexType(slice->slice_starts(dim)));
5287       }
5288       llvm_ir::IrArray src_ir_array =
5289           ir_arrays[fused_computation->num_parameters() + i];
5290       IrArray::Index slice_dst_index(dst_multidim, slice->shape(),
5291                                      index.GetType());
5292       src_ir_array.EmitWriteArrayElement(slice_dst_index, input_ir_values[i],
5293                                          &b_);
5294     };
5295 
5296     ksl.If(StrCat("slice", i), guarding_cond, emit_slice_elem_func);
5297   }
5298   return Status::OK();
5299 }
5300 
5301 Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices(
5302     mlir::Operation* op) {
5303   auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(op);
5304 
5305   constexpr int unroll_factor = 1;
5306 
5307   TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation,
5308                       GetOrCreateSubComputationFromRegion(&fusion.region(),
5309                                                           /*is_fusion=*/true));
5310 
5311   TF_ASSIGN_OR_RETURN(Shape element_shape,
5312                       GetConsistentInputShapeForRootSlices(fused_computation));
5313   TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
5314                       CalculateLaunchDimensions(
5315                           element_shape, ir_emitter_context_->gpu_device_info(),
5316                           {unroll_factor}));
5317 
5318   std::vector<llvm_ir::IrArray> ir_arrays;
5319   TF_ASSIGN_OR_RETURN(auto kernel_thunk,
5320                       BuildKernelThunk(fusion, GetThunkInfo(op), &ir_arrays,
5321                                        launch_dimensions));
5322 
5323   Status emit_status =
5324       ParallelLoopEmitter(
5325           [&](const llvm_ir::IrArray::Index index) -> Status {
5326             return EmitElementForInputFusibleSlices(fused_computation,
5327                                                     ir_arrays, index);
5328           },
5329           element_shape, launch_dimensions, &b_)
5330           .EmitLoop(IrName(mlir::GetNameFromLoc(fusion.getLoc())),
5331                     GetIndexTypeForKernel(
5332                         fusion, launch_dimensions.launch_bound(), &b_));
5333 
5334   thunk_sequence_.emplace_back(std::move(kernel_thunk));
5335 
5336   return emit_status;
5337 }
5338 
5339 Status IrEmitterUnnested::EmitOp(mlir::Operation* op) {
5340   if (mlir::isa<mlir::ConstantOp, mlir::memref::ViewOp,
5341                 mlir::memref::ReinterpretCastOp, mlir::ReturnOp,
5342                 mlir::lmhlo::TerminatorOp>(op)) {
5343     return Status::OK();
5344   }
5345 
5346   if (mlir::isa<mlir::memref::GetGlobalOp>(op)) {
5347     return EmitConstant(op);
5348   }
5349 
5350   if (auto call = mlir::dyn_cast<mlir::lmhlo::CustomCallOp>(op)) {
5351     if (call.call_target_name() == "PadToStatic") {
5352       return EmitPadToStatic(op);
5353     }
5354     if (call.call_target_name() == "SliceToDynamic") {
5355       return EmitSliceToDynamic(op);
5356     }
5357     return EmitCustomCallThunk(op);
5358   }
5359 
5360   if (mlir::isa<mlir::lmhlo_gpu::GEMMOp, mlir::lmhlo_gpu::GEMM_BiasOp>(op)) {
5361     return EmitGemmThunk(op);
5362   }
5363 
5364   if (mlir::isa<mlir::lmhlo_gpu::ConvForwardOp,
5365                 mlir::lmhlo_gpu::ConvForwardFusedOp,
5366                 mlir::lmhlo_gpu::ConvForwardFusedSideInputOp,
5367                 mlir::lmhlo_gpu::ConvBackwardFilterOp,
5368                 mlir::lmhlo_gpu::ConvBackwardInputOp>(op)) {
5369     return EmitConvolutionThunk(op);
5370   }
5371 
5372   if (mlir::isa<mlir::lmhlo_gpu::BatchNormTrainingOp,
5373                 mlir::lmhlo_gpu::BatchNormInferenceOp,
5374                 mlir::lmhlo_gpu::BatchNormGradOp>(op)) {
5375     return EmitBatchNormThunk(op);
5376   }
5377 
5378 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
5379   if (mlir::isa<mlir::lmhlo_gpu::CholeskyOp>(op)) {
5380     return EmitCholeskyThunk(op);
5381   }
5382 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
5383 
5384   if (mlir::isa<mlir::lmhlo::FftOp>(op)) {
5385     return EmitFftThunk(op);
5386   }
5387 
5388   if (mlir::isa<mlir::lmhlo::TriangularSolveOp>(op)) {
5389     return EmitTriangularSolve(op);
5390   }
5391 
5392   if (mlir::isa<mlir::lmhlo::FusionOp>(op)) {
5393     return EmitFusion(op);
5394   }
5395 
5396   if (mlir::isa<mlir::lmhlo::SelectAndScatterOp>(op)) {
5397     return EmitSelectAndScatter(op);
5398   }
5399 
5400   if (mlir::isa<mlir::lmhlo::RngGetAndUpdateStateOp>(op)) {
5401     return EmitRngGetAndUpdateState(op);
5402   }
5403 
5404   if (mlir::isa<mlir::lmhlo::ScatterOp>(op)) {
5405     return EmitScatter(op);
5406   }
5407 
5408   if (mlir::isa<mlir::lmhlo::SortOp>(op)) {
5409     return EmitSort(op);
5410   }
5411 
5412   if (mlir::isa<mlir::lmhlo::ReplicaIdOp>(op)) {
5413     return EmitReplicaOrPartitionId<ReplicaIdThunk, mlir::lmhlo::ReplicaIdOp>(
5414         op);
5415   }
5416 
5417   if (mlir::isa<mlir::lmhlo::PartitionIdOp>(op)) {
5418     return EmitReplicaOrPartitionId<PartitionIdThunk,
5419                                     mlir::lmhlo::PartitionIdOp>(op);
5420   }
5421 
5422   if (mlir::isa<mlir::lmhlo::CollectivePermuteOp>(op)) {
5423     return EmitCollectivePermute(op);
5424   }
5425 
5426   if (mlir::isa<mlir::lmhlo::AllGatherOp>(op)) {
5427     return EmitNcclThunk<NcclAllGatherThunk, mlir::lmhlo::AllGatherOp>(op);
5428   }
5429 
5430   if (mlir::isa<mlir::lmhlo::AllReduceOp>(op)) {
5431     return EmitNcclThunk<NcclAllReduceThunk, mlir::lmhlo::AllReduceOp>(op);
5432   }
5433 
5434   if (mlir::isa<mlir::lmhlo_gpu::AllReduceStartOp>(op)) {
5435     return EmitNcclThunk<NcclAllReduceStartThunk,
5436                          mlir::lmhlo_gpu::AllReduceStartOp>(op);
5437   }
5438 
5439   if (mlir::isa<mlir::lmhlo_gpu::AllReduceDoneOp>(op)) {
5440     return EmitAllReduceDone(op);
5441   }
5442 
5443   if (mlir::isa<mlir::lmhlo::ReduceScatterOp>(op)) {
5444     return EmitNcclThunk<NcclReduceScatterThunk, mlir::lmhlo::ReduceScatterOp>(
5445         op);
5446   }
5447 
5448   if (mlir::isa<mlir::lmhlo::AllToAllOp>(op)) {
5449     return EmitNcclThunk<NcclAllToAllThunk, mlir::lmhlo::AllToAllOp>(op);
5450   }
5451 
5452   if (mlir::isa<mlir::lmhlo::InfeedOp>(op)) {
5453     return EmitInfeed(op);
5454   }
5455 
5456   if (mlir::isa<mlir::lmhlo::OutfeedOp>(op)) {
5457     return EmitOutfeed(op);
5458   }
5459 
5460   if (mlir::isa<mlir::lmhlo::CaseOp>(op)) {
5461     return EmitConditional(op);
5462   }
5463 
5464   if (mlir::isa<mlir::lmhlo::WhileOp>(op)) {
5465     return EmitWhile(op);
5466   }
5467 
5468   return InternalError("Unrecognized op: %s", MlirToString(op));
5469 }
5470 
5471 Status IrEmitterUnnested::EmitLmhloRegion(mlir::Region* region) {
5472   for (mlir::Operation& op : llvm::make_early_inc_range(region->front())) {
5473     TF_RETURN_IF_ERROR(EmitOp(&op));
5474   }
5475   return Status::OK();
5476 }
5477 
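// Builds a ThunkInfo whose profile_annotation has the form
// "Thunk:#hlo_op=<op name>,hlo_module=<module name>#", e.g. (with hypothetical
// names) "Thunk:#hlo_op=fusion.42,hlo_module=cluster_0#".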
5478 Thunk::ThunkInfo IrEmitterUnnested::GetThunkInfo(mlir::Operation* op) {
5479   auto module = op->getParentOfType<mlir::ModuleOp>();
5480   Thunk::ThunkInfo thunk_info;
5481   thunk_info.profile_annotation = absl::StrFormat(
5482       "Thunk:#hlo_op=%s,hlo_module=%s#", mlir::GetNameFromLoc(op->getLoc()),
5483       mlir::GetNameFromLoc(module->getLoc()));
5484   return thunk_info;
5485 }
5486 
5487 void MlirEmitterContext::SetOperation(mlir::Operation* op) {
5488   this->name = mlir::GetNameFromLoc(op->getLoc());
5489 
5490   auto operands = GetHloOperands(op);
5491   auto outputs = GetHloOutputs(op);
5492   for (auto operand : operands) {
5493     operand_shapes.push_back(GetShape(operand));
5494   }
5495   for (auto output : outputs) {
5496     output_shapes.push_back(GetShape(output));
5497   }
5498 }
5499 
5500 }  // namespace gpu
5501 }  // namespace xla
5502