1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
17
18 #include <algorithm>
19 #include <cstring>
20 #include <iterator>
21 #include <memory>
22 #include <string>
23 #include <type_traits>
24 #include <vector>
25
26 #include "absl/algorithm/container.h"
27 #include "absl/container/inlined_vector.h"
28 #include "absl/memory/memory.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/str_format.h"
31 #include "absl/types/optional.h"
32 #include "absl/types/span.h"
33 #include "llvm/ADT/APInt.h"
34 #include "llvm/ADT/SetVector.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/IR/BasicBlock.h"
37 #include "llvm/IR/Function.h"
38 #include "llvm/IR/IRBuilder.h"
39 #include "llvm/IR/Instructions.h"
40 #include "llvm/IR/LLVMContext.h"
41 #include "llvm/IR/Module.h"
42 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
43 #include "mlir/IR/Attributes.h" // from @llvm-project
44 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
45 #include "mlir/IR/Builders.h" // from @llvm-project
46 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
47 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
48 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
49 #include "mlir/IR/Verifier.h" // from @llvm-project
50 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
51 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
52 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
53 #include "tensorflow/compiler/mlir/utils/name_utils.h"
54 #include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
55 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
56 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
57 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
58 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
59 #include "tensorflow/compiler/xla/layout_util.h"
60 #include "tensorflow/compiler/xla/literal.h"
61 #include "tensorflow/compiler/xla/permutation_util.h"
62 #include "tensorflow/compiler/xla/service/collective_ops_utils.h"
63 #include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
64 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
65 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
66 #include "tensorflow/compiler/xla/service/gpu/bef_thunk.h"
67 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
68 #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
69 #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
70 #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
71 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
72 #include "tensorflow/compiler/xla/service/gpu/custom_call_thunk.h"
73 #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
74 #include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
75 #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
76 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
77 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
78 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
79 #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
80 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
81 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
82 #include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
83 #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
84 #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
85 #include "tensorflow/compiler/xla/service/gpu/memset_thunk.h"
86 #include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
87 #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
88 #include "tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h"
89 #include "tensorflow/compiler/xla/service/gpu/nccl_collective_permute_thunk.h"
90 #include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h"
91 #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
92 #include "tensorflow/compiler/xla/service/gpu/replica_id_thunk.h"
93 #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
94 #include "tensorflow/compiler/xla/service/gpu/target_util.h"
95 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
96 #include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"
97 #include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
98 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
99 #include "tensorflow/compiler/xla/service/hlo_computation.h"
100 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
101 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
102 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
103 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
104 #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
105 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
106 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
107 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
108 #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"
109 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
110 #include "tensorflow/compiler/xla/service/name_uniquer.h"
111 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
112 #include "tensorflow/compiler/xla/service/shape_inference.h"
113 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
114 #include "tensorflow/compiler/xla/shape_util.h"
115 #include "tensorflow/compiler/xla/status_macros.h"
116 #include "tensorflow/compiler/xla/types.h"
117 #include "tensorflow/compiler/xla/union_find.h"
118 #include "tensorflow/compiler/xla/util.h"
119 #include "tensorflow/compiler/xla/window_util.h"
120 #include "tensorflow/compiler/xla/xla_data.pb.h"
121 #include "tensorflow/core/lib/core/bits.h"
122 #include "tensorflow/core/lib/core/status.h"
123 #include "tensorflow/core/platform/errors.h"
124 #include "tensorflow/core/platform/logging.h"
125
126 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
127 #include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"
128 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
129
130 namespace xla {
131 namespace gpu {
132
133 namespace {
134
135 using absl::InlinedVector;
136 using absl::nullopt;
137 using absl::optional;
138 using absl::StrCat;
139 using llvm_ir::IrArray;
140 using llvm_ir::IrName;
141
142 const auto kDimX = KernelMappingScheme::DimX;
143 const auto kDimY = KernelMappingScheme::DimY;
144 const auto kDimZ = KernelMappingScheme::DimZ;
145 const auto kDimTot = KernelMappingScheme::DimTot;
146
147 const auto kLinearIndexingX = KernelMappingScheme::LinearIndexingX;
148 const auto kStridedIndexingX = KernelMappingScheme::StridedIndexingX;
149 const auto kStridedLinearIndexingX =
150 KernelMappingScheme::StridedLinearIndexingX;
151
152 // If a dimension is smaller than this, untiled transposition may be more
153 // efficient.
154 const int64_t kMinDimensionToTransposeTiled = 16;
155
156 // Annotates the launch dimensions of the corresponding IR kernel in
157 // `llvm_module`.
158 void AnnotateThunkLaunchDimensions(const LaunchDimensions& launch_dims,
159 const std::string& kernel_name,
160 llvm::Module* llvm_module) {
161 // Add __launch_bounds__ to metadata. This limits registers per thread to
162 // avoid out-of-resources launching errors.
163 llvm::NamedMDNode* nvvm_annotations_node =
164 llvm_module->getOrInsertNamedMetadata("nvvm.annotations");
165 llvm::Function* ir_kernel = llvm_module->getFunction(kernel_name.c_str());
166 llvm::LLVMContext& llvm_context = llvm_module->getContext();
167 llvm::ConstantInt* threads_per_block_ir_value = llvm::ConstantInt::get(
168 llvm::IntegerType::get(llvm_context, /*NumBits=*/32),
169 launch_dims.thread_counts_per_block().x);
170 // Our launch bounds are exact, so we can specify them as reqntidx rather than
171 // maxntidx.
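// As a rough illustration (kernel name and thread count here are hypothetical),
// the metadata added below looks like:
//   !nvvm.annotations = !{!0}
//   !0 = !{void ()* @some_kernel, !"reqntidx", i32 128}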
172 nvvm_annotations_node->addOperand(llvm::MDNode::get(
173 llvm_context,
174 {llvm::ConstantAsMetadata::get(ir_kernel),
175 llvm::MDString::get(llvm_context, "reqntidx"),
176 llvm::ConstantAsMetadata::get(threads_per_block_ir_value)}));
177 }
178
179 bool BinarySearchDenseElementsAttr(mlir::DenseIntElementsAttr elements,
180 int64_t v) {
181 mlir::APInt value(sizeof(int64) * 8, v, /*isSigned=*/true);
182 return std::binary_search(
183 elements.begin(), elements.end(), value,
184 [](const mlir::APInt& x, const mlir::APInt& y) { return x.slt(y); });
185 }
186
187 bool MhloOpIsElementwise(mlir::Operation* op) {
188 CHECK(op->getDialect() ==
189 op->getContext()->getLoadedDialect<mlir::mhlo::MhloDialect>());
190 auto opcode = *MhloToHloOpcode(op);
191 if (HloInstruction::IsOpElementwise(opcode)) {
192 return true;
193 }
194 if (opcode == HloOpcode::kMap) {
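// A map counts as elementwise only if its `dimensions` attribute is the
// identity permutation [0, 1, ..., rank-1], i.e. the map is applied across
// every dimension in order; the loop below checks exactly that.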
195 int iota = 0;
196 for (const llvm::APInt& i :
197 mlir::cast<mlir::mhlo::MapOp>(op).dimensions()) {
198 if (i.getZExtValue() != iota) {
199 return false;
200 }
201 iota++;
202 }
203 return true;
204 }
205 // TODO(timshen): not sure about whether porting
206 // HloFusionInstruction::IsElementwiseImpl() is necessary. HandleFusion()
207 // doesn't use such information.
208 return false;
209 }
210
211 bool IsSingleInstructionFusion(mlir::lmhlo::FusionOp fusion) {
212 int instruction_count = 0;
213 for (mlir::Operation& instr : fusion.region().front()) {
214 if (mlir::isa<mlir::lmhlo::TerminatorOp, mlir::mhlo::ReturnOp,
215 mlir::memref::TensorLoadOp, mlir::memref::TensorStoreOp>(
216 &instr)) {
217 continue;
218 }
219 instruction_count++;
220 }
221 return instruction_count == 1;
222 }
223
224 bool MayPreventVectorization(mlir::Operation* op) {
225 // An empirically chosen constant: unrolling concat with a large number of
226 // arguments causes excessive register spilling.
227 static constexpr int kMaxConcatArgumentsForUnrolling = 10;
228
229 auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(op);
230 const bool is_single_instruction = IsSingleInstructionFusion(fusion);
231
232 for (mlir::Operation& instr : fusion.region().front()) {
233 if (mlir::isa<mlir::lmhlo::TerminatorOp, mlir::mhlo::ReturnOp,
234 mlir::memref::TensorLoadOp, mlir::memref::TensorStoreOp>(
235 &instr)) {
236 continue;
237 }
238 if (is_single_instruction) {
239 auto instr_opcode = *MhloToHloOpcode(&instr);
240 if (MhloOpIsElementwise(&instr)) {
241 switch (instr_opcode) {
242 case HloOpcode::kSin:
243 case HloOpcode::kCos:
244 case HloOpcode::kPower:
245 case HloOpcode::kAtan2:
246 return true;
247 default:
248 return false;
249 }
250 } else if (instr_opcode == HloOpcode::kReduce &&
251 instr.getNumResults() == 1) {
252 // TODO(timshen): check if the to_apply() attribute contains
253 // instructions that break LLVM vectorization.
254 return false;
255 }
256 return true;
257 }
258
259 CHECK(instr.getDialect() ==
260 instr.getContext()->getLoadedDialect<mlir::mhlo::MhloDialect>())
261 << MlirToString(op);
262 switch (*MhloToHloOpcode(&instr)) {
263 case HloOpcode::kReduceWindow:
264 case HloOpcode::kSort:
265 case HloOpcode::kDot:
266 case HloOpcode::kSin:
267 case HloOpcode::kCos:
268 case HloOpcode::kPower:
269 case HloOpcode::kAtan2:
270 return true;
271 case HloOpcode::kConcatenate:
272 if (instr.getOperands().size() > kMaxConcatArgumentsForUnrolling) {
273 return true;
274 }
275 break;
276 case HloOpcode::kReduce:
277 if (instr.getNumResults() > 1) {
278 return true;
279 }
280 break;
281 default:
282 break;
283 }
284 }
285 return false;
286 }
287
288 std::vector<mlir::Operation*> GetOutputOps(mlir::lmhlo::FusionOp fusion) {
289 llvm::SetVector<mlir::Operation*> ops;
290 for (mlir::Value output_value : fusion.getFusionResults()) {
291 ops.insert(output_value.getDefiningOp());
292 }
293 return std::vector<mlir::Operation*>(ops.begin(), ops.end());
294 }
295
296 // Computes the maximum valid unroll factor for a given instruction.
297 int ComputeMaxUnrollFactor(mlir::Type type,
298 const HloModuleConfig& hlo_module_config) {
299 int max_unroll_factor =
300 hlo_module_config.debug_options().xla_gpu_max_kernel_unroll_factor();
301
302 // Find the largest possible power of two to unroll by.
303 // TODO(kramerb): Make this smarter.
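// For example (illustrative numbers only), with a maximum unroll factor of 4
// and a shape holding 6 elements, 4 does not divide 6 but 2 does, so the loop
// below returns 2.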
304
305 auto shaped_type = type.cast<mlir::ShapedType>();
306 int64_t num_elements = std::accumulate(shaped_type.getShape().begin(),
307 shaped_type.getShape().end(), int64{1},
308 std::multiplies<int64>());
309 for (int i = max_unroll_factor; i > 1; i /= 2) {
310 if (num_elements % i == 0) {
311 return i;
312 }
313 }
314
315 // Cannot unroll.
316 return 1;
317 }
318
319 // Computes the maximum valid unroll factor for a given instruction.
320 int ComputeMaxUnrollFactor(mlir::Operation* op,
321 const HloModuleConfig& hlo_module_config) {
322 mlir::Type element_shape = [&] {
323 std::vector<mlir::Type> shapes;
324 // Detect multi-output fusion. Notice that for a reduce in the fusion that
325 // returns a tuple, we don't want to treat it as multi-output fusion. We
326 // want to pass that tuple into ComputeMaxUnrollFactor below. For an actual
327 // MOF, just pass the first element of the root tuple.
328 if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
329 std::vector<mlir::Operation*> fusion_outputs = GetOutputOps(fusion);
330 for (mlir::Value result : fusion_outputs[0]->getResults()) {
331 return result.getType();
332 }
333 } else {
334 for (mlir::Value result : GetHloOutputs(op)) {
335 return result.getType();
336 }
337 }
338 CHECK(false);
339 }();
340 return ComputeMaxUnrollFactor(element_shape, hlo_module_config);
341 }
342
343 // Returns the llvm type for the indices used in the kernel that contains the
344 // hlo instruction. Such indices include the index for the parallel loop and
345 // the indices for the tensors accessed by the kernel. The return type is i32
346 // iff the following conditions are met:
347 // . The launch_size of the kernel is within the range of i32.
348 // . The sizes of all the tensors accessed within the kernel are within the
349 // range of i32.
350 // Otherwise, the return type is i64.
351 llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo,
352 int64_t launch_size, llvm::IRBuilder<>* b) {
353 // Find the unnested hlo instruction for which the kernel is generated.
354 const HloInstruction* unnested_hlo = hlo;
355 const HloComputation* computation = hlo->parent();
356 if (computation->IsFusionComputation()) {
357 unnested_hlo = computation->FusionInstruction();
358 }
359
360 auto shape_in_range = [&](const Shape& s) {
361 bool in_range = true;
362 ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape,
363 const ShapeIndex& /*index*/) {
364 if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) {
365 in_range = false;
366 }
367 });
368
369 return in_range;
370 };
371
372 llvm::Type* i64_ty = b->getInt64Ty();
373 // Check launch dimension
374 if (!IsInt32(launch_size)) {
375 return i64_ty;
376 }
377
378 // Check the size of result tensors
379 if (!shape_in_range(unnested_hlo->shape())) {
380 return i64_ty;
381 }
382
383 auto hlo_shape_in_range = [&](const HloInstruction* operand) -> bool {
384 return shape_in_range(operand->shape());
385 };
386
387 // Check the size of input tensors
388 if (!absl::c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) {
389 return i64_ty;
390 }
391
392 // Check the size of the internal result tensors
393 if (unnested_hlo->opcode() == HloOpcode::kFusion) {
394 if (!absl::c_all_of(
395 unnested_hlo->fused_instructions_computation()->instructions(),
396 hlo_shape_in_range)) {
397 return i64_ty;
398 }
399 }
400
401 return b->getInt32Ty();
402 }
403
404 // The same as GetIndexTypeForKernel, but works with MLIR ops.
405 llvm::Type* GetIndexTypeForKernel(mlir::Operation* op, int64_t launch_size,
406 llvm::IRBuilder<>* b) {
407 auto shape_in_range = [&](const Shape& s) {
408 bool in_range = true;
409 ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape,
410 const ShapeIndex& /*index*/) {
411 if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) {
412 in_range = false;
413 }
414 });
415
416 return in_range;
417 };
418
419 llvm::Type* i64_ty = b->getInt64Ty();
420 // Check launch dimension
421 if (!IsInt32(launch_size)) {
422 return i64_ty;
423 }
424
425 // Check the size of result tensors
426 for (auto result : GetHloOutputs(op)) {
427 if (!shape_in_range(GetShape(result))) {
428 return i64_ty;
429 }
430 }
431
432 auto hlo_shape_in_range = [&](mlir::Value operand) -> bool {
433 return shape_in_range(GetShape(operand));
434 };
435
436 // Check the size of input tensors
437 if (!absl::c_all_of(op->getOperands(), hlo_shape_in_range)) {
438 return i64_ty;
439 }
440
441 // Check the size of the internal result tensors
442 if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
443 auto result = fusion.region().walk([&](mlir::Operation* op) {
444 for (mlir::Value result : op->getResults()) {
445 if (!hlo_shape_in_range(result)) {
446 return mlir::WalkResult::interrupt();
447 }
448 }
449 return mlir::WalkResult::advance();
450 });
451 if (result.wasInterrupted()) {
452 return i64_ty;
453 }
454 }
455
456 return b->getInt32Ty();
457 }
458
459 // Gets the input shape of the ROOT slices, which will be used as the kernel
460 // launch dims. The slice input fusion requires the input shapes of the ROOT
461 // slices to be the same although the (slice) output shapes can be different.
462 //
463 // Returns the input shape of the ROOT slices if all the input shapes of ROOT
464 // slices are the same and the slices are non-strided. Otherwise, returns
465 // FailedPrecondition.
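// For example (hypothetical shapes), a fusion whose root is
// tuple(slice(p0), slice(p1)) with p0 and p1 both of shape f32[1024] yields
// f32[1024], whereas mismatched slice operand shapes return FailedPrecondition.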
466 StatusOr<Shape> GetConsistentInputShapeForRootSlices(
467 const HloComputation* fused_computation) {
468 const HloInstruction& root = *fused_computation->root_instruction();
469 if (root.opcode() == HloOpcode::kSlice) {
470 return root.operands()[0]->shape();
471 }
472
473 CHECK_EQ(root.opcode(), HloOpcode::kTuple);
474 const Shape& first_slice_operand_shape =
475 root.operands()[0]->operands()[0]->shape();
476 for (size_t i = 1; i < root.operands().size(); ++i) {
477 const HloInstruction* slice = root.operands()[i];
478 const Shape& operand_shape = slice->operands()[0]->shape();
479 if (!ShapeUtil::EqualIgnoringElementType(first_slice_operand_shape,
480 operand_shape)) {
481 return FailedPrecondition(
482 "Fused slices do not have the same input shape, fused computation = "
483 "%s.",
484 root.parent()->name());
485 }
486 }
487
488 return first_slice_operand_shape;
489 }
490
491 } // namespace
492
493 IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
494 IrEmitterContext* ir_emitter_context)
495 : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false) {}
496
497 StatusOr<std::unique_ptr<IrEmitterUnnested>> IrEmitterUnnested::Create(
498 const HloModuleConfig& hlo_module_config,
499 IrEmitterContext* ir_emitter_context) {
500 return std::unique_ptr<IrEmitterUnnested>(
501 new IrEmitterUnnested(hlo_module_config, ir_emitter_context));
502 }
503
504 llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
505 absl::string_view name, absl::Span<const BufferAllocation* const> args) {
506 // Compute the kernel name. The opcode string may contain "-" which cannot be
507 // in a PTX function name, so sanitize the name before uniquifying it.
508 string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName(
509 llvm_ir::SanitizeFunctionName(std::string(name)));
510
511 // Create the kernel and add it to the module.
512 llvm::Module* module = ir_emitter_context_->llvm_module();
513 llvm::LLVMContext& context = module->getContext();
514 llvm::FunctionType* kernel_type = llvm::FunctionType::get(
515 /*Result=*/llvm::Type::getVoidTy(context),
516 std::vector<llvm::Type*>(args.size(), b_.getInt8PtrTy()),
517 /*isVarArg=*/false);
518 llvm::Function* kernel =
519 llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage,
520 kernel_name.c_str(), module);
521
522 // Add dereferenceable and alignment information to each of the kernel's
523 // parameters.
524 auto arg_it = kernel->arg_begin();
525 for (size_t arg_no = 0; arg_no < args.size(); ++arg_no) {
526 const BufferAllocation* alloc = args[arg_no];
527 llvm::Argument* fn_arg = &*arg_it;
528 ++arg_it;
529
530 kernel->addDereferenceableAttr(arg_no + 1, alloc->size());
531
532 const int64_t alignment = [&] {
533 if (alloc->is_entry_computation_parameter()) {
534 return kEntryParameterAlignBytes;
535 } else if (alloc->is_constant()) {
536 return kConstantBufferAlignBytes;
537 } else {
538 return kXlaAllocatedBufferAlignBytes;
539 }
540 }();
541
542 kernel->addParamAttr(
543 arg_no,
544 llvm::Attribute::get(context, llvm::Attribute::Alignment, alignment));
545
546 if (alloc->IsPreallocatedTempBuffer()) {
547 fn_arg->setName("temp_buf");
548 } else {
549 fn_arg->setName(StrCat("alloc", alloc->index()));
550 }
551 }
552
553 AnnotateFunctionAsGpuKernel(module, kernel, &b_);
554
555 // TODO(b/65380986): Investigate if adding fast math flags for generated
556 // kernels makes sense.
557
558 // Update the insert point to the entry basic block.
559 llvm::BasicBlock* entry_bb =
560 llvm::BasicBlock::Create(context, /*Name=*/"entry", /*Parent=*/kernel);
561
562 // Emit a "return void" at entry_bb's end, and set the insert point before
563 // that return instruction.
564 b_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb));
565
566 return kernel;
567 }
568
569 StatusOr<BufferAllocation::Slice> IrEmitterUnnested::GetAllocationSlice(
570 mlir::Value v, std::string* constant_name) {
571 return xla::gpu::GetAllocationSlice(v, ir_emitter_context_->allocations(),
572 constant_name);
573 }
574
575 Status IrEmitterUnnested::EmitConstant(mlir::Operation* op) {
576 auto get_global = mlir::cast<mlir::memref::GetGlobalOp>(op);
577 auto module = get_global->getParentOfType<mlir::ModuleOp>();
578 auto global = mlir::cast<mlir::memref::GlobalOp>(
579 module.lookupSymbol(get_global.name()));
580
581 auto literal = global.initial_value()->dyn_cast<mlir::DenseElementsAttr>();
582 TF_RET_CHECK(literal);
583
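// Only tiny literals (at most one element) get a real LLVM initializer below;
// larger constants are copied into info.content for GpuExecutable to
// materialize, and the LLVM global is given a zero placeholder initializer.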
584 const bool should_emit_initializer = literal.getType().getNumElements() <= 1;
585
586 TF_ASSIGN_OR_RETURN(int element_bytes,
587 GetElementTypeBytes(literal.getType().getElementType()));
588 llvm::ArrayType* global_type = llvm::ArrayType::get(
589 b_.getInt8Ty(), literal.getType().getNumElements() * element_bytes);
590
591 GpuExecutable::ConstantInfo info;
592 llvm::Constant* initializer;
593 if (should_emit_initializer) {
594 std::vector<uint8> content;
595 TF_RETURN_IF_ERROR(CopyDenseElementsDataToXlaFormat(literal, &content));
596 initializer = llvm::ConstantDataArray::get<uint8>(
597 ir_emitter_context_->llvm_module()->getContext(), content);
598 } else {
599 TF_RETURN_IF_ERROR(
600 CopyDenseElementsDataToXlaFormat(literal, &info.content));
601 initializer = llvm::ConstantAggregateZero::get(global_type);
602 }
603
604 // These globals will be looked up by name by GpuExecutable so we need to
605 // give them an external linkage. Not all of their uses are visible in
606 // the LLVM IR, so we can't give them a linkage that merely preserves their
607 // names (like available_externally); we also need to ensure that they stick
608 // around even if they're "unused".
609 //
610 // We may have to be more clever here in the future if we notice that we're
611 // keeping around too many globals because of their linkage.
612 llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
613 global_type, /*isConstant=*/should_emit_initializer,
614 llvm::GlobalValue::ExternalLinkage,
615 /*Initializer=*/initializer, global.sym_name(),
616 /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
617 /*AddressSpace=*/0,
618 /*isExternallyInitialized=*/false);
619 global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes));
620 ir_emitter_context_->llvm_module()->getGlobalList().push_back(
621 global_for_const);
622
623 info.symbol_name.assign(global.sym_name().begin(), global.sym_name().end());
624
625 info.allocation_index =
626 global->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc").getInt();
627 ir_emitter_context_->constants().push_back(std::move(info));
628 return Status::OK();
629 }
630
631 static ConditionalThunkConfig GetConditionalThunkConfig(
632 mlir::lmhlo::CaseOp op, std::vector<ThunkSequence> branch_thunk_sequences) {
633 ConditionalThunkConfig config;
634 config.branch_index_is_bool =
635 op.index().getType().cast<mlir::ShapedType>().getElementType().isInteger(
636 /*width=*/1);
637 config.branch_count = op.branches().size();
638 // Pass an empty ThunkInfo to the branch SequentialThunk constructors
639 // because these SequentialThunks are logically "part of" this
640 // ConditionalThunk, and shouldn't be profiled separately from it.
641 config.branch_thunks.reserve(branch_thunk_sequences.size());
642 for (auto& branch_thunk_sequence : branch_thunk_sequences) {
643 config.branch_thunks.emplace_back(new SequentialThunk(
644 Thunk::ThunkInfo(), std::move(branch_thunk_sequence)));
645 }
646 return config;
647 }
648
649 Status IrEmitterUnnested::EmitConditional(mlir::Operation* op) {
650 auto conditional = mlir::cast<mlir::lmhlo::CaseOp>(op);
651
652 std::vector<ThunkSequence> branch_thunks;
653
654 int branch_count = conditional.branches().size();
655 branch_thunks.reserve(branch_count);
656
657 for (int j = 0; j < branch_count; ++j) {
658 mlir::Region* branch_computation = &conditional.branches()[j];
659 TF_ASSIGN_OR_RETURN(
660 auto ir_emitter,
661 IrEmitterUnnested::Create(hlo_module_config_, ir_emitter_context_));
662 TF_RETURN_IF_ERROR(ir_emitter->EmitLmhloRegion(branch_computation));
663 branch_thunks.push_back(std::move(*ir_emitter->ConsumeThunkSequence()));
664 }
665
666 ConditionalThunkConfig config =
667 GetConditionalThunkConfig(conditional, std::move(branch_thunks));
668
669 TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(conditional.index()));
670 AddThunkToThunkSequence(std::unique_ptr<Thunk>(
671 new ConditionalThunk(GetThunkInfo(op), std::move(config), slice)));
672 return Status::OK();
673 }
674
675 // Input = {dynamic array(with dynamic dimension meta data at the end)}
676 // Output = {static array, dynamic_dim0, dynamic_dim1}
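// As a concrete but hypothetical example, for a 2D f32[<=4, <=8] input the
// source buffer holds 4*8 padded floats followed by two int32 values that
// carry the actual (dynamic) dimension sizes.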
677 Status IrEmitterUnnested::EmitPadToStatic(mlir::Operation* op) {
678 // TODO(jurahul): Create an op to represent PadToStatic.
679 auto pad_to_static = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
680 int unroll_factor = 1;
681 std::string ir_name = mlir::GetNameFromLoc(pad_to_static.getLoc());
682
683 const Shape& input_shape = GetShape(pad_to_static.args().front());
684 TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
685 CalculateLaunchDimensions(
686 input_shape, ir_emitter_context_->gpu_device_info(),
687 {unroll_factor}));
688 std::vector<llvm_ir::IrArray> ir_arrays;
689 TF_ASSIGN_OR_RETURN(auto kernel_thunk,
690 BuildKernelThunk(pad_to_static, GetThunkInfo(op),
691 &ir_arrays, launch_dimensions));
692
693 const llvm_ir::IrArray source_array = ir_arrays[0];
694 const llvm_ir::IrArray output_array = ir_arrays[1];
695 auto output_dim_arrays =
696 absl::Span<const llvm_ir::IrArray>(ir_arrays).subspan(2);
697
698 // pseudo code for PadToStatic on a 2d array
699 // int* source_array = input[0];
700 // int* dest_array = output[0];
701 llvm::Value* source_buffer = source_array.GetBasePointer();
702 llvm::Value* raw_buffer =
703 b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo());
704
705 // TODO(jurahul): input_shape here is the static shape of the input (which has
706 // a dynamic shape in XLA). Currently, we are mapping that to a static shaped
707 // memref. When we change that to a more appropriate representation in MLIR,
708 // fix this code to correctly deduce the static shape backing the dynamically
709 // shaped memref.
710 int64_t raw_data_size = ShapeUtil::ByteSizeOf(input_shape);
711
712 // int* dyn_dim0_size = source_array + meta_data_offset;
713 // int* dyn_dim1_size = source_array + meta_data_offset + sizeof(int);
714 std::vector<llvm::Value*> dynamic_dims;
715 for (int64_t i = 1; i < pad_to_static.output().size(); ++i) {
716 // The dynamic size of each dimension is attached at the end of the source
717 // array (operand(0)). We need to extract these values.
718 const Shape& dim_shape = GetShape(pad_to_static.output()[i]);
719 TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32)));
720
721 const int64_t dim_index = i - 1;
722 llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
723 b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32));
724 llvm::Value* dyn_dim_size = b_.CreateLoad(
725 b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()),
726 "dyn_dim_size");
727 dynamic_dims.push_back(dyn_dim_size);
728 }
729
730 // only one thread needs to store the dynamic dimension sizes
731 // int thread_id = GetThreadId();
732 // int block_id = GetBlockId();
733 // if (thread_id == 0 && block_id == 0) {
734 // *output[1] = *dyn_dim0_size;
735 // *output[2] = *dyn_dim1_size;
736 // }
737 KernelSupportLibrary{&b_}.If("is_thread_0", IsBlock0Thread0(&b_), [&] {
738 for (int64_t i = 1; i < pad_to_static.output().size(); ++i) {
739 const int64_t dim_index = i - 1;
740 llvm::Value* dest_dim_size_address =
741 output_dim_arrays[dim_index].GetBasePointer();
742 // output[i] stores dynamic_dim_(i-1)
743 b_.CreateStore(dynamic_dims[i - 1],
744 b_.CreateBitCast(dest_dim_size_address,
745 b_.getInt32Ty()->getPointerTo()));
746 }
747 });
748
749 // int dyn_element_total = 1;
750 // dyn_element_total *= *dyn_dim0_size;
751 // dyn_element_total *= *dyn_dim1_size;
752 llvm::Value* dyn_element_total = llvm::ConstantInt::get(b_.getInt32Ty(), 1);
753 for (llvm::Value* dynamic_dim : dynamic_dims) {
754 dyn_element_total = b_.CreateMul(dyn_element_total, dynamic_dim,
755 /*Name=*/"dyn_element_total");
756 }
757
758 // linear_index = block_id * threads_per_block + thread_id;
759 // if (linear_index < max_num_element) {
760 // Index static_index =
761 // delinearized(linear_index, static_dim0_size, static_dim1_size);
762 // if (linear_index < dyn_element_total) {
763 // Index dyn_index =
764 // delinearized(linear_index, *dyn_dim0_size, *dyn_dim1_size);
765 // dest_array[dyn_index.dim0][dyn_index.dim1] =
766 // source_array[static_index.dim0][static_index.dim1];
767 // }
768 // }
769 llvm_ir::LoopEmitter::BodyEmitter body_generator =
770 [&](const llvm_ir::IrArray::Index& array_index) -> Status {
771 llvm::Value* linearIndex =
772 array_index.Linearize(input_shape.dimensions(), &b_);
773 auto if_in_dyn_bounds = llvm_ir::EmitIfThenElse(
774 b_.CreateICmpULT(linearIndex, dyn_element_total),
775 llvm_ir::IrName(ir_name, "in_dyn_bounds"), &b_, false);
776 // Set IR builder insertion point to the body of the if structure.
777 llvm_ir::SetToFirstInsertPoint(if_in_dyn_bounds.true_block, &b_);
778 llvm_ir::IrArray::Index dyn_index(linearIndex, input_shape,
779 absl::MakeSpan(dynamic_dims), &b_);
780 output_array.EmitWriteArrayElement(
781 dyn_index,
782 source_array.EmitReadArrayElement(array_index, &b_, /*name=*/""), &b_,
783 /*use_linear_index=*/false);
784 return Status::OK();
785 };
786
787 const Shape& data_shape = GetShape(pad_to_static.output().front());
788 TF_RETURN_IF_ERROR(
789 ParallelLoopEmitter(body_generator, data_shape, launch_dimensions, &b_,
790 {unroll_factor})
791 .EmitLoop(ir_name,
792 GetIndexTypeForKernel(
793 pad_to_static, launch_dimensions.launch_bound(), &b_)));
794 thunk_sequence_.emplace_back(std::move(kernel_thunk));
795 return Status::OK();
796 }
797
798 // Input = {static array, dynamic_dim0, dynamic_dim1}
799 // Output = {dynamic array(with dynamic dimension meta data at the end)}
800 Status IrEmitterUnnested::EmitSliceToDynamic(mlir::Operation* op) {
801 // TODO(jurahul): Create an op to represent SliceToDynamic.
802 auto slice_to_dynamic = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
803 int unroll_factor = 1;
804 std::string ir_name = mlir::GetNameFromLoc(slice_to_dynamic.getLoc());
805
806 const Shape& input_shape = GetShape(slice_to_dynamic.args().front());
807 TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
808 CalculateLaunchDimensions(
809 input_shape, ir_emitter_context_->gpu_device_info(),
810 {unroll_factor}));
811 std::vector<llvm_ir::IrArray> ir_arrays;
812 TF_ASSIGN_OR_RETURN(auto kernel_thunk,
813 BuildKernelThunk(slice_to_dynamic, GetThunkInfo(op),
814 &ir_arrays, launch_dimensions));
815
816 TF_RET_CHECK(slice_to_dynamic.output().size() == 1);
817 const Shape& data_shape = GetShape(slice_to_dynamic.output().front());
818
819 // TODO(jurahul): data_shape here is the static shape of the output (which has
820 // a dynamic shape in XLA). Currently, we are mapping that to a static shaped
821 // memref. When we change that to a more appropriate representation in MLIR,
822 // fix this code to correctly deduce the static shape backing the dynamically
823 // shaped memref.
824
825 // calculate the location where metadata needs to be inserted
826 // int* dyn_dim0_size = dest_array + meta_data_offset;
827 // int* dyn_dim1_size = dest_array + meta_data_offset + sizeof(int);
828 int32_t raw_data_size = ShapeUtil::ByteSizeOf(data_shape);
829
830 // pseudo code for sliceToDynamic on a 2d array
831 // int* source_array = input[0];
832 // int* dest_array = output[0];
833 const llvm_ir::IrArray data_array = ir_arrays.back();
834 llvm::Value* dest_buffer = data_array.GetBasePointer();
835 llvm::Value* raw_buffer =
836 b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo());
837
838 // Load dynamic dimensions from memory.
839 std::vector<llvm::Value*> dynamic_dims;
840 for (int64_t i = 1; i < slice_to_dynamic.args().size(); ++i) {
841 // const int64 dim_index = i - 1;
842 llvm::Value* source_buffer = ir_arrays[i].GetBasePointer();
843 llvm::LoadInst* dyn_dim_size = b_.CreateLoad(source_buffer, "dyn_dim_size");
844 dynamic_dims.push_back(dyn_dim_size);
845 }
846
847 // only one thread needs to store the dynamic dimension sizes
848 // int thread_id = GetThreadId();
849 // int block_id = GetBlockId();
850 // if (thread_id == 0 && block_id == 0) {
851 // *dyn_dim0_size = *output[1];
852 // *dyn_dim1_size = *output[2];
853 // }
854 KernelSupportLibrary{&b_}.If("is_thread_0", IsBlock0Thread0(&b_), [&] {
855 for (int64_t i = 1; i < slice_to_dynamic.args().size(); ++i) {
856 const int64_t dim_index = i - 1;
857 llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
858 b_.getInt8Ty(), raw_buffer,
859 raw_data_size + dim_index * sizeof(int32));
860 // output[i] stores dynamic_dim_(i-1)
861 b_.CreateStore(
862 dynamic_dims[dim_index],
863 b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()));
864 }
865 });
866
867 // int dyn_element_total = 1;
868 // dyn_element_total *= dyn_dim0_size;
869 // dyn_element_total *= dyn_dim1_size;
870 llvm::Value* dyn_element_total = llvm::ConstantInt::get(b_.getInt32Ty(), 1);
871 for (llvm::Value* dynamic_dim : dynamic_dims) {
872 dyn_element_total = b_.CreateMul(dyn_element_total, dynamic_dim,
873 /*Name=*/"dyn_element_total");
874 }
875
876 // linear_index = block_id * threads_per_block + thread_id;
877 // if (linear_index < max_num_element) {
878 // Index static_index =
879 // delinearized(linear_index, static_dim0_size, static_dim1_size);
880 // if (linear_index < dyn_element_total) {
881 // Index dyn_index =
882 // delinearized(linear_index, *dyn_dim0_size, *dyn_dim1_size);
883 // dest_array[static_index.dim0][static_index.dim1] =
884 // source_array[dyn_index.dim0][dyn_index.dim1];
885 // }
886 // }
887 llvm_ir::LoopEmitter::BodyEmitter body_generator =
888 [&](const llvm_ir::IrArray::Index& array_index) -> Status {
889 llvm::Value* linearIndex =
890 array_index.Linearize(input_shape.dimensions(), &b_);
891 auto if_in_dyn_bounds = llvm_ir::EmitIfThenElse(
892 b_.CreateICmpULT(linearIndex, dyn_element_total),
893 llvm_ir::IrName(ir_name, "in_dyn_bounds"), &b_, false);
894 // Set IR builder insertion point to the body of the if structure.
895 llvm_ir::SetToFirstInsertPoint(if_in_dyn_bounds.true_block, &b_);
896 llvm_ir::IrArray::Index dyn_index(linearIndex, input_shape,
897 absl::MakeSpan(dynamic_dims), &b_);
898
899 data_array.EmitWriteArrayElement(
900 array_index,
901 ir_arrays[0].EmitReadArrayElement(dyn_index, &b_, /*name=*/"",
902 /*use_linear_index=*/false),
903 &b_);
904 return Status::OK();
905 };
906
907 TF_RETURN_IF_ERROR(
908 ParallelLoopEmitter(body_generator, data_shape, launch_dimensions, &b_,
909 {unroll_factor})
910 .EmitLoop(ir_name, GetIndexTypeForKernel(
911 slice_to_dynamic,
912 launch_dimensions.launch_bound(), &b_)));
913 thunk_sequence_.emplace_back(std::move(kernel_thunk));
914 return Status::OK();
915 }
916
917 Status IrEmitterUnnested::EmitConvolutionThunk(mlir::Operation* op) {
918 using mlir::dyn_cast;
919 using mlir::lmhlo_gpu::Activation;
920 using mlir::lmhlo_gpu::ConvBackwardFilterOp;
921 using mlir::lmhlo_gpu::ConvBackwardInputOp;
922 using mlir::lmhlo_gpu::ConvForwardFusedOp;
923 using mlir::lmhlo_gpu::ConvForwardFusedSideInputOp;
924 using mlir::lmhlo_gpu::ConvForwardOp;
925
926 // Last 2 operands of the convolution operation are the result and scratch.
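// For example, a forward convolution typically carries
// (input, filter, output, scratch), so everything except the last two
// operands is treated as a convolution input below.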
927 std::vector<BufferAllocation::Slice> operand_slices;
928 int64_t num_operands = op->getNumOperands();
929 operand_slices.reserve(num_operands - 2);
930 for (mlir::Value operand : op->getOperands().drop_back(2)) {
931 TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(operand));
932 operand_slices.push_back(slice);
933 }
934
935 mlir::Value conv_result = op->getOperand(num_operands - 2);
936 mlir::Value scratch_result = op->getOperand(num_operands - 1);
937 TF_ASSIGN_OR_RETURN(auto conv_result_slice, GetAllocationSlice(conv_result));
938 TF_ASSIGN_OR_RETURN(auto scratch_slice, GetAllocationSlice(scratch_result));
939
940 auto apply_layout = [](const Shape& shape, mlir::ArrayAttr layout_attrib) {
941 mlir::SmallVector<int64, 4> minor_to_major = llvm::to_vector<4>(
942 llvm::map_range(layout_attrib, [](mlir::Attribute a) -> int64 {
943 return static_cast<int64>(a.cast<mlir::IntegerAttr>().getInt());
944 }));
945 return ShapeUtil::MakeShapeWithLayout(shape.element_type(),
946 shape.dimensions(), minor_to_major);
947 };
948
949 GpuConvDescriptor descriptor;
950
951 auto fill_conv_descriptor = [&](auto op) {
952 descriptor.operand0_shape = apply_layout(
953 GetShape(op->getOperand(0)), op.backend_config().operand_0_layout());
954 descriptor.operand1_shape = apply_layout(
955 GetShape(op->getOperand(1)), op.backend_config().operand_1_layout());
956 descriptor.result_shape = apply_layout(GetShape(conv_result),
957 op.backend_config().result_layout());
958 descriptor.dnums = ConvertConvDimensionNumbers(op.dimension_numbers());
959 descriptor.scratch_size = scratch_slice.size();
960 mlir::DenseIntElementsAttr window_strides = op.window_strides().getValue();
961 mlir::DenseIntElementsAttr padding = op.padding().getValue();
962 mlir::DenseIntElementsAttr lhs_dilation = op.lhs_dilation().getValue();
963 mlir::DenseIntElementsAttr rhs_dilation = op.rhs_dilation().getValue();
964 mlir::DenseElementsAttr window_reversal = op.window_reversal().getValue();
965 for (auto index : llvm::seq<int>(0, window_strides.getNumElements())) {
966 WindowDimension* dim = descriptor.window.add_dimensions();
967 // Window size for a convolution is the same as the kernel size.
968 // Kernel size of the convolution is operand1_shape. We need to look at
969 // the convolution dimension numbers kernel spatial dimensions to get
970 // the window size.
971 int kernel_dim = descriptor.dnums.kernel_spatial_dimensions(index);
972 dim->set_size(descriptor.operand0_shape.dimensions(kernel_dim));
973 dim->set_stride(window_strides.getValue<int64>(index));
974 dim->set_padding_low(padding.getValue<int64>(index));
975 dim->set_padding_high(padding.getValue<int64>(index));
976 dim->set_base_dilation(lhs_dilation.getValue<int64>(index));
977 dim->set_window_dilation(rhs_dilation.getValue<int64>(index));
978 dim->set_window_reversal(window_reversal.getValue<bool>(index));
979 }
980 descriptor.feature_group_count = op.feature_group_count();
981 descriptor.backend_config.set_algorithm(
982 op.backend_config().algorithm().getInt());
983 descriptor.backend_config.set_tensor_ops_enabled(
984 op.backend_config().tensor_ops_enabled().getValue());
985 descriptor.backend_config.set_conv_result_scale(
986 op.result_scale().convertToDouble());
987 };
988
989 auto set_activation_mode = [&](auto op) -> Status {
990 TF_ASSIGN_OR_RETURN(stream_executor::dnn::ActivationMode activation_mode,
991 ConvertConvActivationMode(op.activation_mode()));
992 descriptor.backend_config.set_activation_mode(
993 static_cast<int64>(activation_mode));
994 return Status::OK();
995 };
996
997 if (auto conv = dyn_cast<ConvForwardOp>(op)) {
998 descriptor.kind = CudnnConvKind::kForward;
999 fill_conv_descriptor(conv);
1000 } else if (auto conv = dyn_cast<ConvBackwardInputOp>(op)) {
1001 descriptor.kind = CudnnConvKind::kBackwardInput;
1002 fill_conv_descriptor(conv);
1003 } else if (auto conv = dyn_cast<ConvBackwardFilterOp>(op)) {
1004 descriptor.kind = CudnnConvKind::kBackwardFilter;
1005 fill_conv_descriptor(conv);
1006 } else if (auto conv = dyn_cast<ConvForwardFusedOp>(op)) {
1007 descriptor.kind = CudnnConvKind::kForwardActivation;
1008 fill_conv_descriptor(conv);
1009 TF_RETURN_IF_ERROR(set_activation_mode(conv));
1010 } else if (auto conv = dyn_cast<ConvForwardFusedSideInputOp>(op)) {
1011 descriptor.kind = CudnnConvKind::kForwardActivation;
1012 fill_conv_descriptor(conv);
1013 TF_RETURN_IF_ERROR(set_activation_mode(conv));
1014 descriptor.backend_config.set_side_input_scale(
1015 conv.side_input_scale().convertToDouble());
1016 } else {
1017 return InternalError("Unexpected operation");
1018 }
1019 TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(descriptor, ""));
1020 AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
1021 GetThunkInfo(op), std::move(config), std::move(operand_slices),
1022 conv_result_slice, scratch_slice));
1023 return Status::OK();
1024 }
1025
1026 Status IrEmitterUnnested::EmitGemmThunk(mlir::Operation* op) {
1027 auto make_thunk_for_gemm =
1028 [&](auto op, absl::optional<BufferAllocation::Slice> bias = absl::nullopt,
1029 absl::optional<double> gemm_bias_beta = absl::nullopt,
1030 bool implements_whole_instruction =
1031 true) -> StatusOr<std::unique_ptr<Thunk>> {
1032 TF_ASSIGN_OR_RETURN(auto lhs, GetAllocationSlice(op.lhs()));
1033 TF_ASSIGN_OR_RETURN(auto rhs, GetAllocationSlice(op.rhs()));
1034 TF_ASSIGN_OR_RETURN(auto output, GetAllocationSlice(op.output()));
1035 std::vector<BufferAllocation::Slice> inputs = {lhs, rhs};
1036 if (bias.has_value()) {
1037 inputs.push_back(bias.value());
1038 }
1039
1040 if (IsBefThunkEnabled() && op.lhs_stride() && op.rhs_stride()) {
1041 // TODO(loreno): TFRT support for zero-strided gemm calls
1042 return CreateBefThunk(GetThunkInfo(op), op, inputs,
1043 std::vector<BufferAllocation::Slice>{output});
1044 }
1045
1046 GpuGemmConfig config;
1047 GemmBackendConfig& backend = config.backend_config;
1048 config.output_shape = GetShape(op.output());
1049 config.lhs_shape = GetShape(op.lhs());
1050 config.rhs_shape = GetShape(op.rhs());
1051 backend.Clear();
1052 if (op.algorithm()) {
1053 backend.set_selected_algorithm(*op.algorithm());
1054 }
1055 backend.set_alpha_real(op.alpha_real().convertToDouble());
1056 backend.set_alpha_imag(op.alpha_imag().convertToDouble());
1057 backend.set_batch_size(op.batch_size());
1058 if (gemm_bias_beta.has_value()) {
1059 backend.set_beta(gemm_bias_beta.value());
1060 }
1061 backend.set_lhs_stride(op.lhs_stride());
1062 backend.set_rhs_stride(op.rhs_stride());
1063
1064 auto& dims = *backend.mutable_dot_dimension_numbers();
1065 auto mlir_dims = op.dot_dimension_numbers();
1066
1067 auto fill_dims = [](mlir::DenseElementsAttr mlir_dim, auto* config_attrs) {
1068 for (llvm::APInt e : mlir_dim.getIntValues())
1069 config_attrs->Add(e.getSExtValue());
1070 };
1071 fill_dims(mlir_dims.lhs_batching_dimensions(),
1072 dims.mutable_lhs_batch_dimensions());
1073 fill_dims(mlir_dims.rhs_batching_dimensions(),
1074 dims.mutable_rhs_batch_dimensions());
1075 fill_dims(mlir_dims.lhs_contracting_dimensions(),
1076 dims.mutable_lhs_contracting_dimensions());
1077 fill_dims(mlir_dims.rhs_contracting_dimensions(),
1078 dims.mutable_rhs_contracting_dimensions());
1079
1080 return std::unique_ptr<Thunk>(
1081 new GemmThunk(GetThunkInfo(op), std::move(config), lhs, rhs, output,
1082 implements_whole_instruction));
1083 };
1084
1085 TF_ASSIGN_OR_RETURN(auto thunk, [&]() -> StatusOr<std::unique_ptr<Thunk>> {
1086 if (auto gemm = mlir::dyn_cast<mlir::lmhlo_gpu::GEMMOp>(op)) {
1087 return make_thunk_for_gemm(gemm);
1088 }
1089
1090 if (auto gemm = mlir::dyn_cast<mlir::lmhlo_gpu::GEMM_BiasOp>(op)) {
1091 double gemm_bias_beta = gemm.beta().convertToDouble();
1092 TF_ASSIGN_OR_RETURN(auto bias, GetAllocationSlice(gemm.bias()));
1093 TF_ASSIGN_OR_RETURN(auto output, GetAllocationSlice(gemm.output()));
1094
1095 // The bias is passed inside the output buffer. If those buffers are
1096 // shared we can just use it, otherwise copy the bias values into the
1097 // output buffer first.
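// Concretely (as implemented just below): when the buffers differ, a
// DeviceToDeviceCopyThunk first copies the bias into the output buffer and a
// GemmThunk with beta set then accumulates into it.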
1098 if (bias == output) {
1099 return make_thunk_for_gemm(gemm, bias, gemm_bias_beta);
1100 }
1101
1102 ThunkSequence thunks;
1103 thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
1104 Thunk::ThunkInfo(),
1105 /*source_buffer=*/bias,
1106 /*destination_buffer=*/output,
1107 /*mem_size=*/
1108 ShapeUtil::ByteSizeOf(GetShape(gemm.output()))));
1109 TF_ASSIGN_OR_RETURN(
1110 auto thunk,
1111 make_thunk_for_gemm(gemm, bias, gemm_bias_beta,
1112 /*implements_whole_instruction=*/false));
1113 thunks.push_back(std::move(thunk));
1114 return std::unique_ptr<Thunk>(
1115 new SequentialThunk(GetThunkInfo(op), std::move(thunks)));
1116 }
1117
1118 return tensorflow::errors::Internal("Unexpected op.");
1119 }());
1120
1121 AddThunkToThunkSequence(std::move(thunk));
1122 return Status::OK();
1123 }
1124
1125 namespace {
1126 // An MLIR value and its name as defined in the ODS spec.
1127 struct NamedValue {
1128 mlir::Value value;
1129 absl::string_view name;
1130 };
1131
1132 // Verifies that the given batch norm is well formed for thunk emission. This
1133 // requires that all statistics operands (mean, stddev, etc.) have F32 type and
1134 // that all the non-statistics operands match in shape, element type, and
1135 // layout (which maps to them having the same memref type).
1136 Status VerifyBatchNormForThunkEmission(
1137 mlir::ArrayRef<NamedValue> statistics_operands,
1138 mlir::ArrayRef<NamedValue> other_operands) {
1139 for (const NamedValue& v : statistics_operands) {
1140 // Note: MLIR verification will ensure that the operands of the batchnorm
1141 // LHLO are valid memref types.
1142 if (!v.value.getType().cast<mlir::MemRefType>().getElementType().isF32()) {
1143 return Unimplemented("Operand %s of batch norm should have F32 type",
1144 v.name);
1145 }
1146 }
1147 if (other_operands.empty()) {
1148 return Status::OK();
1149 }
1150
1151 mlir::Type first_type = other_operands.front().value.getType();
1152 absl::string_view first_name = other_operands.front().name;
1153
1154 for (const NamedValue& v : other_operands.drop_front(1)) {
1155 if (v.value.getType() != first_type) {
1156 return Unimplemented("%s and %s for batch norm should have same types",
1157 v.name, first_name);
1158 }
1159 }
1160
1161 return Status::OK();
1162 }
1163
1164 // Determines whether to enable the row-optimized codegen. When we have a
1165 // fusion with only pointwise operations, scalar broadcasting, and row
1166 // broadcasting, we can emit a kernel that vectorizes the row loads.
1167 // This speeds up the kernel, in particular on A100.
1168 // Returns a pair<bool, int>. The bool indicates whether we should try to
1169 // enable row vectorization. The int is the number of inputs with the
1170 // highest rank.
1171 std::pair<bool, int> RowVectorizationEnabled(mlir::lmhlo::FusionOp fusion) {
1172 const auto is_row_major = [](mlir::Value value) {
1173 // Only tested when the inputs are row-major, so only
1174 // enable that case. It might also work when only the
1175 // innermost dimension is contiguous.
1176 return LayoutUtil::IsMonotonicWithDim0Major(GetShape(value).layout());
1177 };
1178 bool row_vectorized =
1179 fusion.getFusionResults().size() == 1 && // Not tested with MOF.
1180 absl::c_all_of(GetHloOperands(fusion), is_row_major) &&
1181 absl::c_all_of(GetHloOutputs(fusion), is_row_major);
1182
1183 // Check that the operations in the fusion are supported. Each
1184 // supported operation (or category) must be manually vetted as XLA
1185 // only unrolls and relies on LLVM to vectorize. But this is brittle.
1186 // Currently tested and supported operations:
1187 // Elementwise, scalar and row broadcasting.
1188 //
1189 // We also detect at the same time if there is a row broadcasting
1190 // operation.
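// A "row broadcast" here means a BroadcastInDimOp whose single broadcast
// dimension is the last output dimension, e.g. (illustratively) broadcasting
// f32[1024] into f32[256, 1024].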
1191 bool some_row_broadcasting = false;
1192 auto out_rank =
1193 fusion.getFusionResults()[0].getType().cast<mlir::ShapedType>().getRank();
1194 int num_big_inputs = 0;
1195 for (mlir::Operation& op : fusion.region().front()) {
1196 if (auto load = mlir::dyn_cast<mlir::memref::TensorLoadOp>(op)) {
1197 auto rank = load.getResult().getType().cast<mlir::ShapedType>().getRank();
1198 num_big_inputs += static_cast<int>(rank == out_rank);
1199 continue;
1200 } else if (mlir::isa<mlir::memref::TensorStoreOp, mlir::lmhlo::TerminatorOp,
1201 mlir::mhlo::ReturnOp, mlir::mhlo::ConstOp,
1202 mlir::lmhlo::ConstOp>(op)) {
1203 continue;
1204 }
1205 HloOpcode opcode = *MhloToHloOpcode(&op);
1206 if (HloInstruction::IsOpElementwise(opcode)) {
1207 continue;
1208 }
1209
1210 if (auto broadcast = mlir::dyn_cast<mlir::mhlo::BroadcastInDimOp>(op)) {
1211 if (broadcast.broadcast_dimensions().size() == 0) {
1212 continue;
1213 }
1214 std::vector<int64> broadcast_dimensions;
1215 for (const llvm::APInt& int_value : broadcast.broadcast_dimensions()) {
1216 broadcast_dimensions.push_back(int_value.getSExtValue());
1217 }
1218
1219 auto rank = GetShape(broadcast.getResult()).rank();
1220 if (broadcast_dimensions.size() == 1 &&
1221 broadcast_dimensions.back() == (rank - 1)) {
1222 some_row_broadcasting = true;
1223 continue;
1224 }
1225 }
1226 VLOG(2) << "Row vectorization not enabled due to this op: "
1227 << MlirToString(&op);
1228 return std::make_pair(false, 0);
1229 }
1230 // Trigger only when there is a row broadcasting.
1231 return std::make_pair(row_vectorized && some_row_broadcasting,
1232 num_big_inputs);
1233 }
1234 } // namespace
1235
1236 Status IrEmitterUnnested::EmitBatchNormThunk(mlir::Operation* op) {
1237 auto get_batch_norm_config = [](auto op, mlir::Value output) {
1238 CudnnBatchNormConfig config;
1239 config.output_shape = GetShape(output);
1240 config.output_type = config.output_shape.element_type();
1241 config.epsilon = op.epsilon().convertToFloat();
1242 config.feature_index = op.feature_index();
1243 return config;
1244 };
1245
1246 // The statistics operands for batch norm operations need to be F32 type,
1247 // and the rest of the operands need to match in shape, layout, and element
1248 // type.
1249 if (auto bn_train =
1250 mlir::dyn_cast<mlir::lmhlo_gpu::BatchNormTrainingOp>(op)) {
1251 TF_RETURN_IF_ERROR(VerifyBatchNormForThunkEmission(
1252 /*statistics_operands=*/
1253 {{bn_train.scale(), "scale"},
1254 {bn_train.offset(), "offset"},
1255 {bn_train.batch_mean(), "batch_mean"},
1256 {bn_train.batch_stddev(), "batch_stddev"}},
1257 /*other_operands=*/
1258 {{bn_train.operand(), "operand"}, {bn_train.output(), "output"}}));
1259 TF_ASSIGN_OR_RETURN(auto operand, GetAllocationSlice(bn_train.operand()));
1260 TF_ASSIGN_OR_RETURN(auto scale, GetAllocationSlice(bn_train.scale()));
1261 TF_ASSIGN_OR_RETURN(auto offset, GetAllocationSlice(bn_train.offset()));
1262
1263 // BatchNormTraining returns a tuple of three elements: data, calculated
1264 // mean, and calculated 1/sqrt(variance + epsilon).
1265 TF_ASSIGN_OR_RETURN(auto output_data,
1266 GetAllocationSlice(bn_train.output()));
1267 TF_ASSIGN_OR_RETURN(auto output_mean,
1268 GetAllocationSlice(bn_train.batch_mean()));
1269 TF_ASSIGN_OR_RETURN(auto output_inv_stddev,
1270 GetAllocationSlice(bn_train.batch_stddev()));
1271
1272 AddThunkToThunkSequence(
1273 absl::make_unique<CudnnBatchNormForwardTrainingThunk>(
1274 GetThunkInfo(op),
1275 /*config=*/get_batch_norm_config(bn_train, bn_train.output()),
1276 /*operand=*/operand,
1277 /*scale=*/scale,
1278 /*offset=*/offset,
1279 /*output_data=*/output_data,
1280 /*output_mean=*/output_mean,
1281 /*output_inv_stddev=*/output_inv_stddev));
1282 return Status::OK();
1283 }
1284
1285 if (auto bn_grad = mlir::dyn_cast<mlir::lmhlo_gpu::BatchNormGradOp>(op)) {
1286 TF_RETURN_IF_ERROR(VerifyBatchNormForThunkEmission(
1287 /*statistics_operands=*/
1288 {{bn_grad.scale(), "scale"},
1289 {bn_grad.mean(), "mean"},
1290 {bn_grad.stddev(), "stddev"},
1291 {bn_grad.grad_scale(), "grad_scale"},
1292 {bn_grad.grad_offset(), "grad_offset"}},
1293 /*other_operands=*/
1294 {{bn_grad.operand(), "operand"},
1295 {bn_grad.grad_output(), "grad_output"},
1296 {bn_grad.grad_operand(), "grad_operand"}}));
1297
1298 TF_ASSIGN_OR_RETURN(auto operand, GetAllocationSlice(bn_grad.operand()));
1299 TF_ASSIGN_OR_RETURN(auto scale, GetAllocationSlice(bn_grad.scale()));
1300 TF_ASSIGN_OR_RETURN(auto mean, GetAllocationSlice(bn_grad.mean()));
1301 TF_ASSIGN_OR_RETURN(auto inv_stddev, GetAllocationSlice(bn_grad.stddev()));
1302 TF_ASSIGN_OR_RETURN(auto grad_output,
1303 GetAllocationSlice(bn_grad.grad_output()));
1304
1305 // BatchNormGrad returns a tuple of three elements: grad_data, grad_scale,
1306 // grad_offset.
1307 TF_ASSIGN_OR_RETURN(auto output_grad_data,
1308 GetAllocationSlice(bn_grad.grad_operand()));
1309 TF_ASSIGN_OR_RETURN(auto output_grad_scale,
1310 GetAllocationSlice(bn_grad.grad_scale()));
1311 TF_ASSIGN_OR_RETURN(auto output_grad_offset,
1312 GetAllocationSlice(bn_grad.grad_offset()));
1313
1314 CudnnBatchNormConfig config;
1315 config.output_shape = GetShape(bn_grad.grad_output());
1316 config.output_type = config.output_shape.element_type();
1317 config.epsilon = bn_grad.epsilon().convertToFloat();
1318 config.feature_index = bn_grad.feature_index();
1319
1320 AddThunkToThunkSequence(absl::make_unique<CudnnBatchNormBackwardThunk>(
1321 GetThunkInfo(op),
1322 /*config=*/get_batch_norm_config(bn_grad, bn_grad.grad_output()),
1323 /*operand=*/operand,
1324 /*scale=*/scale,
1325 /*mean=*/mean,
1326 /*inv_stddev=*/inv_stddev,
1327 /*grad_output=*/grad_output,
1328 /*output_grad_data=*/output_grad_data,
1329 /*output_grad_scale=*/output_grad_scale,
1330 /*output_grad_offset=*/output_grad_offset));
1331 return Status::OK();
1332 }
1333
1334 if (auto bn_inference =
1335 mlir::dyn_cast<mlir::lmhlo_gpu::BatchNormInferenceOp>(op)) {
1336 TF_RETURN_IF_ERROR(
1337 VerifyBatchNormForThunkEmission(/*statistics_operands=*/
1338 {{bn_inference.scale(), "scale"},
1339 {bn_inference.offset(), "offset"},
1340 {bn_inference.mean(), "mean"},
1341 {bn_inference.stddev(), "stddev"}},
1342 /*other_operands=*/
1343 {{bn_inference.operand(), "operand"},
1344 {bn_inference.output(), "output"}}));
1345
1346 TF_ASSIGN_OR_RETURN(auto operand,
1347 GetAllocationSlice(bn_inference.operand()));
1348 TF_ASSIGN_OR_RETURN(auto scale, GetAllocationSlice(bn_inference.scale()));
1349 TF_ASSIGN_OR_RETURN(auto offset, GetAllocationSlice(bn_inference.offset()));
1350 TF_ASSIGN_OR_RETURN(auto mean, GetAllocationSlice(bn_inference.mean()));
1351 TF_ASSIGN_OR_RETURN(auto variance,
1352 GetAllocationSlice(bn_inference.stddev()));
1353 TF_ASSIGN_OR_RETURN(auto output, GetAllocationSlice(bn_inference.output()));
1354
1355 AddThunkToThunkSequence(absl::make_unique<
1356 CudnnBatchNormForwardInferenceThunk>(
1357 GetThunkInfo(op),
1358 /*config=*/get_batch_norm_config(bn_inference, bn_inference.output()),
1359 /*operand=*/operand,
1360 /*scale=*/scale,
1361 /*offset=*/offset,
1362 /*mean=*/mean,
1363 /*variance=*/variance,
1364 /*output=*/output));
1365 return Status::OK();
1366 }
1367
1368 return Unimplemented("Unsupported batch norm operation");
1369 }
1370
1371 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1372 Status IrEmitterUnnested::EmitCholeskyThunk(mlir::Operation* op) {
1373 auto cholesky_op = mlir::cast<mlir::lmhlo_gpu::CholeskyOp>(op);
1374
1375 const Shape shape = GetShape(cholesky_op.input());
1376 int ndim = shape.dimensions_size();
1377 CHECK_GE(ndim, 2);
1378 int64_t n = shape.dimensions(ndim - 1);
1379
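// The input is a batch of square matrices of shape [..., n, n]; every leading
// dimension is a batch dimension, so the batch size is their product
// (e.g. an f32[4, 3, n, n] input yields batch_size = 12).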
1380 const auto& dims = shape.dimensions();
1381 int64_t batch_size =
1382 std::accumulate(dims.begin(), dims.end() - 2, int64{1},
1383 [](int64_t a, int64_t b) { return a * b; });
1384
1385 TF_ASSIGN_OR_RETURN(auto operand_buffer,
1386 GetAllocationSlice(cholesky_op.input()));
1387 TF_ASSIGN_OR_RETURN(auto a_buffer, GetAllocationSlice(cholesky_op.output()));
1388 TF_ASSIGN_OR_RETURN(auto workspace_buffer,
1389 GetAllocationSlice(cholesky_op.scratch()));
1390 TF_ASSIGN_OR_RETURN(auto info_buffer, GetAllocationSlice(cholesky_op.info()));
1391
1392 ThunkSequence thunks;
1393
1394 if (operand_buffer != a_buffer) {
1395 thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
1396 GetThunkInfo(op),
1397 /*source_address=*/operand_buffer,
1398 /*destination_buffer=*/a_buffer,
1399 /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
1400 }
1401
1402 CholeskyOptions options;
1403 options.set_lower(cholesky_op.is_lower());
1404 thunks.push_back(absl::make_unique<CholeskyThunk>(
1405 GetThunkInfo(op), options, a_buffer, workspace_buffer, info_buffer,
1406 shape.element_type(), batch_size, n));
1407
1408 // Elide the sequential thunk if there's no copy.
1409 if (thunks.size() == 1) {
1410 AddThunkToThunkSequence(std::move(thunks[0]));
1411 } else {
1412 AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
1413 GetThunkInfo(op), std::move(thunks)));
1414 }
1415
1416 return Status::OK();
1417 }
1418 #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
1419
1420 Status IrEmitterUnnested::EmitCustomCallThunk(mlir::Operation* op) {
1421 auto custom_call = mlir::cast<mlir::lmhlo::CustomCallOp>(op);
1422 const std::string call_target_name = custom_call.call_target_name().str();
1423
1424 void* call_target = CustomCallTargetRegistry::Global()->Lookup(
1425 call_target_name, std::string(platform_name()));
1426 if (!call_target) {
1427 return Unimplemented(
1428 "No registered implementation for custom call to \"%s\"",
1429 call_target_name);
1430 }
1431
1432 std::vector<CustomCallThunk::OptionalSlice> operands;
1433 std::vector<CustomCallThunk::OptionalSlice> results;
1434
1435 if (custom_call.target_arg_mapping()) {
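// target_arg_mapping is present when the original custom call had token
// operands/results with no LMHLO counterpart: the mapping records which
// target positions the remaining values occupy, and unmapped positions are
// left as empty OptionalSlices ("token holes").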
1436 auto values_to_slices_with_token_holes =
1437 [&](mlir::ValueRange operands, mlir::ArrayAttr op_to_target_mapping,
1438 mlir::IntegerAttr num_target)
1439 -> StatusOr<std::vector<CustomCallThunk::OptionalSlice>> {
1440 std::vector<CustomCallThunk::OptionalSlice> slices(num_target.getInt());
1441 for (auto index_and_value_it :
1442 llvm::zip(op_to_target_mapping, operands)) {
1443 mlir::Attribute index_attr = std::get<0>(index_and_value_it);
1444 mlir::Value value = std::get<1>(index_and_value_it);
1445 int64_t index = index_attr.cast<mlir::IntegerAttr>().getInt();
1446 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
1447 GetAllocationSlice(value));
1448 slices[index] = slice;
1449 }
1450 return slices;
1451 };
1452
1453 mlir::lmhlo::CustomCallTargetArgMapping target_mapping =
1454 *custom_call.target_arg_mapping();
1455 TF_ASSIGN_OR_RETURN(
1456 operands, values_to_slices_with_token_holes(
1457 custom_call.args(), target_mapping.args_to_target_args(),
1458 target_mapping.num_args()));
1459 TF_ASSIGN_OR_RETURN(results, values_to_slices_with_token_holes(
1460 custom_call.output(),
1461 target_mapping.results_to_target_results(),
1462 target_mapping.num_results()));
1463 } else {
1464 auto values_to_slices = [&](mlir::ValueRange values)
1465 -> StatusOr<std::vector<CustomCallThunk::OptionalSlice>> {
1466 std::vector<CustomCallThunk::OptionalSlice> slices;
1467 for (mlir::Value value : values) {
1468 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice slice,
1469 GetAllocationSlice(value));
1470 slices.push_back(slice);
1471 }
1472 return slices;
1473 };
1474
1475 TF_ASSIGN_OR_RETURN(operands, values_to_slices(custom_call.args()));
1476 TF_ASSIGN_OR_RETURN(results, values_to_slices(custom_call.output()));
1477 }
1478
1479 CustomCallThunk::CustomCallTarget custom_call_target;
1480
1481 // For information about this calling convention, see
1482 // xla/g3doc/custom_call.md.
1483 switch (custom_call.api_version()) {
1484 case mlir::mhlo::CustomCallApiVersion::API_VERSION_ORIGINAL:
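// The original API does not take an XlaCustomCallStatus, so wrap the
// registered target in a lambda that simply drops the status argument.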
1485 using original_call_type =
1486 void (*)(CustomCallThunk::Stream /*stream*/, void** /*buffers*/,
1487 const char* /*opaque*/, size_t /*opaque_len*/);
1488 custom_call_target = [call_target](CustomCallThunk::Stream stream,
1489 void** buffers, const char* opaque,
1490 size_t opaque_len,
1491 XlaCustomCallStatus*) {
1492 auto typed_call_target =
1493 reinterpret_cast<original_call_type>(call_target);
1494 typed_call_target(stream, buffers, opaque, opaque_len);
1495 };
1496 break;
1497 case mlir::mhlo::CustomCallApiVersion::API_VERSION_STATUS_RETURNING:
1498 using status_returning_call_type =
1499 void (*)(CustomCallThunk::Stream /*stream*/, void** /*buffers*/,
1500 const char* /*opaque*/, size_t /*opaque_len*/,
1501 XlaCustomCallStatus* /*status*/);
1502 custom_call_target =
1503 reinterpret_cast<status_returning_call_type>(call_target);
1504 break;
1505 default:
1506 return InternalError("Unknown custom-call API version enum value: %d",
1507 custom_call.api_version());
1508 }
1509
1510 AddThunkToThunkSequence(absl::make_unique<CustomCallThunk>(
1511 GetThunkInfo(op), std::move(custom_call_target), std::move(operands),
1512 std::move(results), custom_call.backend_config().str()));
1513 return Status::OK();
1514 }
1515
1516 Status IrEmitterUnnested::EmitFftThunk(mlir::Operation* op) {
1517 auto fft_op = mlir::cast<mlir::lmhlo::FftOp>(op);
1518 const Shape operand_shape = GetShape(fft_op.operand());
1519 const Shape output_shape = GetShape(fft_op.output());
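// The FFT thunk only handles the default descending (dim0-major) layout for
// both the input and the output.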
1520 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand_shape.layout()));
1521 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(output_shape.layout()));
1522
1523 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice arg_slice,
1524 GetAllocationSlice(fft_op.operand()));
1525 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dest_slice,
1526 GetAllocationSlice(fft_op.output()));
1527 TF_ASSIGN_OR_RETURN(xla::FftType fft_type, ConvertFftType(fft_op.fft_type()));
1528 auto fft_length_values = fft_op.fft_length().getValues<int64>();
1529 std::vector<int64> fft_length(fft_length_values.begin(),
1530 fft_length_values.end());
1531 AddThunkToThunkSequence(
1532 absl::make_unique<FftThunk>(GetThunkInfo(op), fft_type, fft_length,
1533 /*input_buffer=*/arg_slice,
1534 /*output_buffer=*/dest_slice,
1535 /*input_shape=*/operand_shape,
1536 /*output_shape=*/output_shape));
1537 return Status::OK();
1538 }
1539
1540 Status IrEmitterUnnested::EmitTriangularSolve(mlir::Operation* op) {
1541 auto triangular_solve_op = mlir::cast<mlir::lmhlo::TriangularSolveOp>(op);
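// The solver expects the trailing two (matrix) dimensions in Fortran
// (column-major) order, i.e. a minor-to-major layout that starts with
// {rank - 2, rank - 1}.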
1542 auto has_fortran_layout = [](mlir::DenseIntElementsAttr layout_attr) {
1543 int64_t n = layout_attr.getNumElements();
1544 return layout_attr.getValue<int64_t>({0}) == n - 2 &&
1545 layout_attr.getValue<int64_t>({1}) == n - 1;
1546 };
1547 TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_a()));
1548 TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_b()));
1549 TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_output()));
1550
1551 const Shape b_shape = GetShape(triangular_solve_op.b());
1552
1553 const Shape output_shape = GetShape(triangular_solve_op.output());
1554
1555 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice a_slice,
1556 GetAllocationSlice(triangular_solve_op.a()));
1557 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice b_slice,
1558 GetAllocationSlice(triangular_solve_op.b()));
1559 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1560 GetAllocationSlice(triangular_solve_op.output()));
1561 TF_ASSIGN_OR_RETURN(TriangularSolveOptions_Transpose transpose_a,
1562 ConvertTranspose(triangular_solve_op.transpose_a()));
1563
1564 ThunkSequence thunks;
1565
1566 // Triangular solve is in-place on 'b', so copy 'b' to the output if they
1567 // aren't the same buffer.
1568 if (b_slice != output_slice) {
1569 thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
1570 Thunk::ThunkInfo(),
1571 /*source_address=*/b_slice,
1572 /*destination_buffer=*/output_slice,
1573 /*mem_size=*/ShapeUtil::ByteSizeOf(b_shape)));
1574 }
1575
1576 int64_t m = b_shape.dimensions(b_shape.rank() - 2);
1577 int64_t n = b_shape.dimensions(b_shape.rank() - 1);
1578 int64_t batch_size = std::accumulate(
1579 b_shape.dimensions().begin(), b_shape.dimensions().end() - 2, int64{1},
1580 [](int64_t a, int64_t b) { return a * b; });
1581 int64_t elem_size =
1582 ShapeUtil::ByteSizeOfPrimitiveType(output_shape.element_type());
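// 'a' is m x m when it multiplies from the left (op(a) * x = b) and n x n when
// it multiplies from the right (x * op(a) = b); 'b' is always m x n.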
1583 int64_t a_batch_stride =
1584 triangular_solve_op.left_side() ? m * m * elem_size : n * n * elem_size;
1585 int64_t b_batch_stride = m * n * elem_size;
1586 TriangularSolveOptions options;
1587 options.set_left_side(triangular_solve_op.left_side());
1588 options.set_lower(triangular_solve_op.lower());
1589 options.set_unit_diagonal(triangular_solve_op.unit_diagonal());
1590 options.set_transpose_a(transpose_a);
1591 thunks.push_back(absl::make_unique<TriangularSolveThunk>(
1592 GetThunkInfo(op), options,
1593 /*a_input_buffer=*/a_slice,
1594 /*b_input_buffer=*/output_slice, output_shape.element_type(), batch_size,
1595 m, n, a_batch_stride, b_batch_stride));
1596
1597 // Elide the sequential thunk if there's no copy.
1598 if (thunks.size() == 1) {
1599 AddThunkToThunkSequence(std::move(thunks[0]));
1600 } else {
1601 AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
1602 GetThunkInfo(op), std::move(thunks)));
1603 }
1604 return Status::OK();
1605 }
1606
1607 // Convert the following form of fusion region:
1608 // fusion() {
1609 // %0 = tensor_load %external_memref0
1610 // %1 = tensor_load %external_memref1
1611 // ...
1612 // tensor_store %ret, %external_memref2
1613 // }
1614 // to
1615 // fusion(%external_memref0, %external_memref1) (^bb(%0, %1) {
1616 // ...
1617 // mhlo.return %ret
1618 // })
1619 //
1620 // So that it's suitable for MHLO -> XLA HLO conversion.
1621 // This function won't be needed once ElementalIrEmitter migrates to take MHLO
1622 // instead.
1623 static Status ProcessFusionForConversion(mlir::Region* region,
1624 std::vector<Shape>* operand_shapes,
1625 std::vector<Shape>* output_shapes) {
1626 std::vector<mlir::memref::TensorLoadOp> loads;
1627 std::vector<mlir::memref::TensorStoreOp> stores;
1628
1629 region->walk([&](mlir::memref::TensorLoadOp load) {
1630 if (load.memref().getParentRegion() != region) {
1631 loads.push_back(load);
1632 }
1633 });
1634
1635 region->walk([&](mlir::memref::TensorStoreOp store) {
1636 if (store.memref().getParentRegion() != region) {
1637 stores.push_back(store);
1638 }
1639 });
1640
1641 for (auto load : loads) {
1642 auto arg = region->addArgument(load.getType());
1643 load.replaceAllUsesWith(arg);
1644 Shape shape = GetShape(load.getResult());
1645 operand_shapes->push_back(std::move(shape));
1646 load.erase();
1647 }
1648
1649 std::vector<mlir::Value> returned_values;
1650 for (auto store : stores) {
1651 Shape shape = GetShape(store.memref());
1652 output_shapes->push_back(shape);
1653
1654 returned_values.push_back(store.tensor());
1655 store.erase();
1656 }
1657
1658 region->back().back().erase();
1659 auto b = mlir::OpBuilder::atBlockEnd(&region->back());
1660 auto loc = returned_values[0].getLoc();
1661 b.create<mlir::mhlo::ReturnOp>(loc, returned_values);
1662 return Status::OK();
1663 }
1664
1665 // TODO(timshen): update the comment once the HandleFusion code path is deleted.
1666 //
1667 // This is migrated from IrEmitter::HandleFusion() with IrEmitterUnnested as the
1668 // subclass. The logic is de-virtualized and less scattered.
1669 Status IrEmitterUnnested::EmitLoopFusion(mlir::Operation* op) {
1670 auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(op);
1671 MlirEmitterContext context;
1672 context.SetOperation(fusion);
1673
1674 TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation,
1675 GetOrCreateSubComputationFromRegion(&fusion.region(),
1676 /*is_fusion=*/true));
1677
1678 int unroll_factor;
1679 if (!MayPreventVectorization(fusion)) {
1680 unroll_factor = ComputeMaxUnrollFactor(fusion, hlo_module_config_);
1681 } else {
1682 unroll_factor = 1;
1683 }
1684
1685 bool row_vectorized;
1686 int num_big_inputs;
1687 std::tie(row_vectorized, num_big_inputs) = RowVectorizationEnabled(fusion);
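// Enable the "few waves" launch heuristic only if the fusion consists purely
// of elementwise ops, constants, and simple broadcasts (plus the loads/stores
// and terminator of the region itself); anything else disables it.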
1688 bool few_waves = [fusion, row_vectorized, num_big_inputs]() mutable {
1689 for (mlir::Operation& op : fusion.region().front()) {
1690 if (mlir::isa<mlir::memref::TensorLoadOp, mlir::memref::TensorStoreOp,
1691 mlir::lmhlo::TerminatorOp, mlir::mhlo::ReturnOp,
1692 mlir::mhlo::ConstOp>(op)) {
1693 continue;
1694 }
1695 HloOpcode opcode = *MhloToHloOpcode(&op);
1696 if (HloInstruction::IsOpElementwise(opcode)) {
1697 continue;
1698 }
1699 if (auto broadcast = mlir::dyn_cast<mlir::mhlo::BroadcastInDimOp>(op)) {
1700 if (broadcast.broadcast_dimensions().empty() ||
1701 // More than two big inputs cause a speed regression.
1702 (row_vectorized && num_big_inputs <= 2)) {
1703 continue;
1704 }
1705 }
1706 VLOG(2) << "few_waves not enabled due to: " << MlirToString(&op);
1707 return false;
1708 }
1709 return true;
1710 }();
1711
1712 Shape element_shape = context.output_shapes[0];
1713 LaunchDimensionsConfig launch_config{unroll_factor, few_waves,
1714 row_vectorized};
1715 // Check that the shape is supported.
1716 if (launch_config.row_vectorized &&
1717 ThreadsPerBlockRowVectorized(element_shape,
1718 ir_emitter_context_->gpu_device_info(),
1719 launch_config) <= 0) {
1720 VLOG(2) << "Cancelling row_vectorization as the shape isn't supported.";
1721 launch_config.row_vectorized = false;
1722 launch_config.few_waves = false;
1723 }
1724
1725 TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
1726 CalculateLaunchDimensions(
1727 element_shape, ir_emitter_context_->gpu_device_info(),
1728 launch_config));
1729
1730 std::vector<llvm_ir::IrArray> ir_arrays;
1731 Thunk* kernel_thunk;
1732 {
1733 TF_ASSIGN_OR_RETURN(std::unique_ptr<KernelThunk> kernel_thunk_ptr,
1734 BuildKernelThunk(fusion, GetThunkInfo(op), &ir_arrays,
1735 launch_dimensions));
1736 kernel_thunk = kernel_thunk_ptr.get();
1737 thunk_sequence_.emplace_back(std::move(kernel_thunk_ptr));
1738 }
1739
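// BuildKernelThunk binds one IrArray per kernel argument, operands first and
// then outputs, so split the span accordingly.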
1740 auto operand_arrays =
1741 absl::MakeSpan(ir_arrays).subspan(0, context.operand_shapes.size());
1742 auto output_element_arrays = absl::MakeSpan(ir_arrays).subspan(
1743 context.operand_shapes.size(), context.output_shapes.size());
1744
1745 GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
1746 GetNestedComputer());
1747 FusedIrEmitter fused_emitter(&elemental_emitter);
1748
1749 for (int i = 0; i < context.operand_shapes.size(); i++) {
1750 auto* builder = &b_;
1751 auto ir_array = operand_arrays[i];
1752 fused_emitter.BindGenerator(
1753 fused_computation->parameter_instruction(i),
1754 [builder, ir_array](llvm_ir::IrArray::Index index) {
1755 return ir_array.EmitReadArrayElement(index, builder);
1756 });
1757 }
1758 TF_ASSIGN_OR_RETURN(
1759 auto element_generator,
1760 fused_emitter.GetGenerator(fused_computation->root_instruction()));
1761
1762 llvm::Type* index_type =
1763 GetIndexTypeForKernel(fusion, launch_dimensions.launch_bound(), &b_);
1764
1765 if (context.output_shapes.size() > 1) {
1766 // For multioutput fusion, we need to emit each operand and the root.
1767 TF_RETURN_IF_ERROR(
1768 ParallelLoopEmitter(element_generator, output_element_arrays,
1769 launch_dimensions, &b_, launch_config)
1770 .EmitLoop(context.name, index_type));
1771 } else {
1772 TF_RETURN_IF_ERROR(
1773 ParallelLoopEmitter(element_generator, output_element_arrays[0],
1774 launch_dimensions, &b_, launch_config)
1775 .EmitLoop(context.name, index_type));
1776 }
1777
1778 b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
1779 return Status::OK();
1780 }
1781
1782 Status IrEmitterUnnested::EmitFusion(mlir::Operation* op) {
1783 auto fusion_op = mlir::cast<mlir::lmhlo::FusionOp>(op);
1784 const bool is_single_instruction = IsSingleInstructionFusion(fusion_op);
1785
1786 // Infer the layout of fusion internal nodes.
1787 const FusionLayoutAnalysis layout_analysis(fusion_op);
1788
1789 auto fusion_results = fusion_op.getFusionResults();
1790 TF_RET_CHECK(!fusion_results.empty());
1791 if (fusion_results.size() > 1) {
1792 // In the case of a root tuple, the fusion can be either a reduce or a
1793 // slice input fusion.
1794 if (IsInputFusibleSlices(op, /*verify_no_strides=*/true)) {
1795 // The emitter doesn't support all cases. If it's not supported, fallback
1796 // to ElementalIrEmitter.
1797 auto status = EmitInputFusibleNonStridedSlices(op);
1798 if (status.code() == tensorflow::error::FAILED_PRECONDITION) {
1799 return EmitLoopFusion(op);
1800 }
1801 return status;
1802 }
1803
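// Otherwise, if any fusion root is a reduction from/to contiguous dimensions,
// emit the whole fusion with the specialized (parallel) reduction emitter.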
1804 const bool is_parallel_reduce =
1805 absl::c_any_of(fusion_results, [&layout_analysis](mlir::Value result) {
1806 mlir::Operation* maybe_reduce = result.getDefiningOp();
1807 return maybe_reduce->getNumResults() == 1 &&
1808 IsReductionFromOrToContiguousDimensions(maybe_reduce,
1809 layout_analysis);
1810 });
1811
1812 if (is_parallel_reduce) {
1813 return EmitUnnestedReduction(op, layout_analysis);
1814 }
1815 }
1816
1817 mlir::Operation* fusion_root = fusion_results[0].getDefiningOp();
1818 if (mlir::isa<mlir::mhlo::ScatterOp>(fusion_root)) {
1819 TF_ASSIGN_OR_RETURN(
1820 const HloComputation* fused_computation,
1821 GetOrCreateSubComputationFromRegion(&fusion_op.region(),
1822 /*is_fusion=*/true));
1823 auto* root = fused_computation->root_instruction();
1824
1825 ThunkSequence thunks;
1826 // The initialization from 'operand' uses different loop bounds, so
1827 // emit it in a separate kernel. Treat it like a loop fusion, writing to
1828 // the output buffer.
1829 {
1830 auto unroll_factor =
1831 ComputeMaxUnrollFactor(fusion_op, hlo_module_config_);
1832 const Shape& element_shape = root->shape();
1833 TF_ASSIGN_OR_RETURN(
1834 LaunchDimensions launch_dimensions,
1835 CalculateLaunchDimensions(element_shape,
1836 ir_emitter_context_->gpu_device_info(),
1837 {unroll_factor, /*few_waves=*/false}));
1838
1839 std::vector<llvm_ir::IrArray> ir_arrays;
1840 TF_ASSIGN_OR_RETURN(auto operand_thunk,
1841 BuildKernelThunk(op, Thunk::ThunkInfo(), &ir_arrays,
1842 launch_dimensions));
1843 thunks.push_back(std::move(operand_thunk));
1844
1845 GpuElementalIrEmitter operand_elemental_emitter(
1846 hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
1847 GetNestedComputer());
1848 FusedIrEmitter operand_fused_emitter(&operand_elemental_emitter);
1849 for (int i = 0; i < fused_computation->num_parameters(); i++) {
1850 auto fused_operand = fused_computation->parameter_instruction(i);
1851 operand_fused_emitter.BindGenerator(
1852 fused_operand, [this, &ir_arrays, i,
1853 fused_operand](llvm_ir::IrArray::Index index) {
1854 return ir_arrays[i].EmitReadArrayElement(index, &b_,
1855 fused_operand->name());
1856 });
1857 }
1858 TF_ASSIGN_OR_RETURN(auto generator,
1859 operand_fused_emitter.GetGenerator(root->operand(0)));
1860
1861 TF_RETURN_IF_ERROR(
1862 ParallelLoopEmitter(generator, ir_arrays.back(), launch_dimensions,
1863 &b_, {unroll_factor})
1864 .EmitLoop(IrName(mlir::GetNameFromLoc(fusion_op.getLoc())),
1865 GetIndexTypeForKernel(
1866 fusion_op, launch_dimensions.launch_bound(), &b_)));
1867 }
1868
1869 // Now build the actual scatter, reading and writing to the freshly
1870 // filled output buffer.
1871 {
1872 const Shape& updates_shape = root->operand(2)->shape();
1873 TF_ASSIGN_OR_RETURN(
1874 LaunchDimensions launch_dimensions,
1875 CalculateLaunchDimensions(updates_shape,
1876 ir_emitter_context_->gpu_device_info()));
1877 std::vector<llvm_ir::IrArray> ir_arrays;
1878 TF_ASSIGN_OR_RETURN(auto scatter_thunk,
1879 BuildKernelThunk(op, Thunk::ThunkInfo(), &ir_arrays,
1880 launch_dimensions));
1881 thunks.push_back(std::move(scatter_thunk));
1882 // Spin up a new fused emitter for the scatter kernel and emit it.
1883 GpuElementalIrEmitter scatter_elemental_emitter(
1884 hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
1885 GetNestedComputer());
1886 FusedIrEmitter scatter_fused_emitter(&scatter_elemental_emitter);
1887 for (int i = 0; i < fused_computation->num_parameters(); i++) {
1888 auto fused_operand = fused_computation->parameter_instruction(i);
1889 scatter_fused_emitter.BindGenerator(
1890 fused_operand, [this, &ir_arrays, i,
1891 fused_operand](llvm_ir::IrArray::Index index) {
1892 return ir_arrays[i].EmitReadArrayElement(index, &b_,
1893 fused_operand->name());
1894 });
1895 }
1896
1897 TF_ASSIGN_OR_RETURN(const auto dim_numbers,
1898 mlir::LhloDialectEmitter::GetScatterDimensionNumbers(
1899 root, fusion_op.getContext()));
1900
1901 ScatterDescriptor desc;
1902 desc.name = IrName(root);
1903 desc.operand_shape = root->operand(0)->shape();
1904 desc.scatter_indices_shape = root->operand(1)->shape();
1905 desc.updates_shape = updates_shape;
1906 desc.dim_numbers = dim_numbers;
1907 desc.unique_indices = root->unique_indices();
1908 desc.update_computation = root->called_computations()[0];
1909 desc.output = ir_arrays.back();
1910 TF_ASSIGN_OR_RETURN(desc.scatter_indices_gen,
1911 scatter_fused_emitter.GetGenerator(root->operand(1)));
1912 TF_ASSIGN_OR_RETURN(desc.updates_gen,
1913 scatter_fused_emitter.GetGenerator(root->operand(2)));
1914 desc.get_index_type = [&](int64_t launch_size) {
1915 return GetIndexTypeForKernel(root, launch_size, &b_);
1916 };
1917
1918 TF_RETURN_IF_ERROR(
1919 EmitScatter(desc, thunks.back().get(), launch_dimensions));
1920 }
1921 AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
1922 GetThunkInfo(op), std::move(thunks)));
1923 return Status::OK();
1924 }
1925
1926 // HandleFusion specializes reduction from a multi-dimensional array to
1927 // a 1D array. The specialized version requires an initializer thunk that
1928 // initializes the output array to the initial value of the reduce.
1929 if (mlir::isa<mlir::mhlo::ReduceOp>(fusion_root) &&
1930 fusion_root->getNumResults() == 1 &&
1931 IsReductionFromOrToContiguousDimensions(fusion_root, layout_analysis)) {
1932 return EmitUnnestedReduction(op, layout_analysis);
1933 }
1934
1935 if (!is_single_instruction &&
1936 CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
1937 fusion_op, ir_emitter_context_->allocations())) {
1938 // Fusion node with dynamic-update-slice as the root where the op's input
1939 // (i.e. array to update) shares the same slice as its output. In this case
1940 // we have a special algorithm that modifies the output in place without
1941 // touching the un-updated elements.
1942 CHECK_EQ(1, GetHloOutputs(op).size());
1943
1944 TF_ASSIGN_OR_RETURN(
1945 const HloComputation* fused_computation,
1946 GetOrCreateSubComputationFromRegion(&fusion_op.region(),
1947 /*is_fusion=*/true));
1948
1949 // Shape of the dynamic-update-slice's "update" operand.
1950 Shape update_shape =
1951 fused_computation->root_instruction()->operand(1)->shape();
1952
1953 TF_ASSIGN_OR_RETURN(
1954 LaunchDimensions launch_dimensions,
1955 CalculateLaunchDimensions(update_shape,
1956 ir_emitter_context_->gpu_device_info()));
1957
1958 // Set up kernel thunk and fused ir emitter.
1959 std::vector<llvm_ir::IrArray> ir_arrays;
1960 TF_ASSIGN_OR_RETURN(auto fusion_thunk,
1961 BuildKernelThunk(fusion_op, GetThunkInfo(op),
1962 &ir_arrays, launch_dimensions));
1963 AddThunkToThunkSequence(std::move(fusion_thunk));
1964
1965 GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
1966 ir_emitter_context_->llvm_module(),
1967 &b_, GetNestedComputer());
1968
1969 FusedIrEmitter fused_emitter(&elemental_emitter);
1970
1971 for (int i = 0; i < fused_computation->num_parameters(); i++) {
1972 auto fused_operand = fused_computation->parameter_instruction(i);
1973 fused_emitter.BindGenerator(
1974 fused_operand, [this, &ir_arrays, i,
1975 fused_operand](const llvm_ir::IrArray::Index& index) {
1976 return ir_arrays[i].EmitReadArrayElement(index, &b_,
1977 fused_operand->name());
1978 });
1979 }
1980
1981 // Array to write into. Because this is an in-place operation, this is the
1982 // same as operand 0's array.
1983 const IrArray& output_array = ir_arrays.back();
1984
1985 return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace(
1986 fused_computation, output_array, &fused_emitter, launch_dimensions,
1987 &b_);
1988 }
1989
1990 if (auto copy = mlir::dyn_cast<mlir::mhlo::CopyOp>(fusion_root)) {
1991 if (IsSingleInstructionFusion(fusion_op)) {
1992 auto operands = GetHloOperands(fusion_op);
1993 auto outputs = GetHloOutputs(fusion_op);
1994 TF_RET_CHECK(operands.size() == 1);
1995 TF_RET_CHECK(outputs.size() == 1);
1996
1997 auto operand_shape = GetShape(operands[0]);
1998 auto output_shape = GetShape(outputs[0]);
1999
2000 CHECK(ShapeUtil::Compatible(operand_shape, output_shape));
2001 auto maybe_slice = GetAllocationSlice(operands[0]);
2002 if (LayoutUtil::Equal(operand_shape.layout(), output_shape.layout()) &&
2003 maybe_slice.ok()) {
2004 // Copy the operand into the output if it's not the same buffer already.
2005 auto operand_buffer = *maybe_slice;
2006 auto destination_buffer = *GetAllocationSlice(outputs[0]);
2007 if (operand_buffer != destination_buffer) {
2008 AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
2009 GetThunkInfo(op),
2010 /*source_address=*/operand_buffer,
2011 /*destination_buffer=*/destination_buffer,
2012 /*mem_size=*/
2013 ByteSizeOf(operand_shape)));
2014 }
2015 return Status::OK();
2016 }
2017 }
2018 }
2019
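// Try the specialized 0-2-1 transpose tiling emitter first; if the fusion
// doesn't match that pattern, fall back to the generic loop emitter below.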
2020 TF_ASSIGN_OR_RETURN(const bool matched_021, CheckAndEmitHloWithTile021(op));
2021 if (matched_021) {
2022 return Status::OK();
2023 }
2024
2025 return EmitLoopFusion(op);
2026 }
2027
2028 Status IrEmitterUnnested::EmitExtraOutputsForReduce(
2029 absl::Span<const llvm_ir::IrArray> result_ir_arrays,
2030 const IrArray::Index& index, bool use_linear_index,
2031 absl::Span<const std::pair<llvm_ir::ElementGenerator, int>>
2032 extra_output_gens) {
2033 // Compute all extra output values before writing them. This avoids
2034 // overwriting aliased input/output buffers before all reads have occurred.
2035 absl::InlinedVector<llvm::Value*, 8> extra_output_ir_values;
2036 for (int i = 0; i < extra_output_gens.size(); ++i) {
2037 TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
2038 extra_output_gens[i].first(index));
2039 extra_output_ir_values.push_back(extra_output_ir_value);
2040 }
2041 for (int i = 0; i < extra_output_gens.size(); ++i) {
2042 result_ir_arrays[extra_output_gens[i].second].EmitWriteArrayElement(
2043 index, extra_output_ir_values[i], &b_, use_linear_index);
2044 }
2045 return Status::OK();
2046 }
2047
2048 Status IrEmitterUnnested::AssertNonDeterminismIsOkay(const string& op_name) {
2049 if (hlo_module_config_.debug_options().xla_gpu_deterministic_ops()) {
2050 return Unimplemented(
2051 "HLO instruction %s does not have a deterministic implementation, "
2052 "but run-to-run determinism is required by "
2053 "--xla_gpu_deterministic_ops.",
2054 op_name);
2055 }
2056 return Status::OK();
2057 }
2058
2059 Status IrEmitterUnnested::EmitSelectAndScatter(mlir::Operation* op) {
2060 auto select_and_scatter_op = mlir::cast<mlir::lmhlo::SelectAndScatterOp>(op);
2061
2062 const Shape source_shape = GetShape(select_and_scatter_op.source());
2063 const Shape operand_shape = GetShape(select_and_scatter_op.operand());
2064 const int64_t rank = operand_shape.rank();
2065
2066 CHECK_EQ(rank, source_shape.rank());
2067 if (select_and_scatter_op.window_dimensions()) {
2068 CHECK_EQ(rank, select_and_scatter_op.window_dimensions()->size());
2069 }
2070
2071 TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay(
2072 mlir::GetNameFromLoc(select_and_scatter_op.getLoc())));
2073
2074 std::string name = mlir::GetNameFromLoc(select_and_scatter_op.getLoc());
2075
2076 // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk
2077 // consisting of two thunks, an initializer KernelThunk that initializes
2078 // the output and another KernelThunk that accumulates the scattered
2079 // elements.
2080 ThunkSequence thunks;
2081 thunks.emplace_back();
2082 TF_ASSIGN_OR_RETURN(thunks.back(), BuildInitializerThunk(
2083 op, select_and_scatter_op.init_value(),
2084 select_and_scatter_op.out()));
2085
2086 TF_ASSIGN_OR_RETURN(
2087 LaunchDimensions launch_dimensions,
2088 CalculateLaunchDimensions(source_shape,
2089 ir_emitter_context_->gpu_device_info()));
2090 std::vector<llvm_ir::IrArray> ir_arrays;
2091 thunks.emplace_back();
2092 // Init value is not needed in IR emission.
2093 TF_ASSIGN_OR_RETURN(
2094 thunks.back(),
2095 BuildKernelThunk(
2096 select_and_scatter_op,
2097 {select_and_scatter_op.operand(), select_and_scatter_op.source(),
2098 select_and_scatter_op.out()},
2099 Thunk::ThunkInfo(), &ir_arrays, launch_dimensions));
2100
2101 CHECK_EQ(ir_arrays.size(), 3);
2102 const IrArray& operand_array = ir_arrays[0];
2103 const IrArray& source_array = ir_arrays[1];
2104 const IrArray& out_array = ir_arrays[2];
2105
2106 auto select_and_scatter_thunk =
2107 absl::make_unique<SequentialThunk>(GetThunkInfo(op), std::move(thunks));
2108
2109 llvm::Type* index_type = GetIndexTypeForKernel(
2110 select_and_scatter_op, launch_dimensions.launch_bound(), &b_);
2111 auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
2112 return llvm::ConstantInt::get(index_type, c);
2113 };
2114
2115 // kSelectAndScatter is implemented as two kernel launches: the first launch
2116 // initializes the output array to the given initial value,
2117 // and the second accumulates the "source" matrix to the
2118 // selected elements in the output array. The first launch is already
2119 // implemented by the initializer thunk generated earlier, so this function
2120 // only needs to take care of the select-and-scatter part.
2121 //
2122 // Pseudo code for select-and-scatter:
2123 //
2124 // for (coordinates S in the source): # This loop is parallel.
2125 // initialized_flag = false
2126 // for (coordinates W in the window):
2127 // I = S * stride + W - pad_low
2128 // if I within bounds of operand:
2129 // if !(initialized_flag and select(selected_value, operand(I))):
2130 // selected_value = operand(I)
2131 // selected_index = I
2132 // initialized_flag = true
2133 // output(selected_index) = scatter(output(selected_index), source(S))
2134 auto loop_body_emitter = [&](const IrArray::Index& source_index) -> Status {
2135 // Allocate space to keep the currently selected value, its index, and a
2136 // boolean flag indicating whether the value is initialized. The
2137 // initialized_flag starts out false.
2138 llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
2139 llvm_ir::PrimitiveTypeToIrType(operand_shape.element_type(),
2140 ir_emitter_context_->llvm_module()),
2141 "selected_value_address", &b_);
2142
2143 llvm::Value* selected_index_address =
2144 llvm_ir::EmitAllocaAtFunctionEntryWithCount(
2145 index_type, index_typed_constant(rank), "selected_index_address",
2146 &b_);
2147
2148 llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
2149 b_.getInt1Ty(), "initialized_flag_address", &b_);
2150 Store(b_.getInt1(false), initialized_flag_address);
2151
2152 // Create the inner loop to iterate over the window.
2153 llvm_ir::ForLoopNest window_loops(absl::StrCat(name, "inner"), &b_,
2154 index_type);
2155
2156 DimensionVector window_size;
2157 mlir::DenseIntElementsAttr window_dimensions =
2158 select_and_scatter_op.window_dimensions().getValue();
2159 for (const auto& dim : window_dimensions) {
2160 window_size.push_back(dim.getSExtValue());
2161 CHECK_GT(dim.getSExtValue(), 0);
2162 }
2163
2164 const IrArray::Index window_index = window_loops.AddLoopsForShape(
2165 ShapeUtil::MakeShape(operand_shape.element_type(), window_size),
2166 "window");
2167 llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
2168 &b_);
2169
2170 // Compute the operand index to visit and check whether it is within the
2171 // operand bounds. Using an unsigned comparison also covers the
2172 // operand index >= 0 check.
2173 std::vector<llvm::Value*> operand_multi_index(source_index.size());
2174 llvm::Value* in_bounds_condition = b_.getInt1(true);
2175
2176 auto strides = *select_and_scatter_op.window_strides();
2177 auto paddings = *select_and_scatter_op.padding();
2178
2179 for (auto stride_and_padding :
2180 llvm::enumerate(llvm::zip(strides, paddings))) {
2181 const int i = stride_and_padding.index();
2182 int64_t stride = std::get<0>(stride_and_padding.value()).getSExtValue();
2183 int64_t padding = std::get<1>(stride_and_padding.value()).getSExtValue();
2184
2185 llvm::Value* strided_index =
2186 NSWMul(source_index[i], index_typed_constant(stride));
2187 operand_multi_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]),
2188 index_typed_constant(padding));
2189 llvm::Value* index_condition = ICmpULT(
2190 operand_multi_index[i],
2191 index_typed_constant(ShapeUtil::GetDimension(operand_shape, i)));
2192 in_bounds_condition = And(in_bounds_condition, index_condition);
2193 }
2194
2195 // Only need to do something if the operand index is within the bounds.
2196 // First check if the initialized_flag is set.
2197 llvm_ir::LlvmIfData if_in_bounds =
2198 llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
2199 llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
2200 llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
2201 Load(initialized_flag_address), "initialized", &b_);
2202
2203 // If the initialized_flag is false, initialize the selected value and index
2204 // with the currently visiting operand.
2205 llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &b_);
2206 const auto save_operand_index = [&](const IrArray::Index& operand_index) {
2207 for (int64_t i = 0; i < rank; ++i) {
2208 llvm::Value* selected_index_address_slot =
2209 InBoundsGEP(selected_index_address, {b_.getInt32(i)});
2210 Store(operand_index[i], selected_index_address_slot);
2211 }
2212 };
2213 IrArray::Index operand_index(operand_multi_index, operand_shape,
2214 index_type);
2215 llvm::Value* operand_data =
2216 operand_array.EmitReadArrayElement(operand_index, &b_);
2217 Store(operand_data, selected_value_address);
2218 save_operand_index(operand_index);
2219 Store(b_.getInt1(true), initialized_flag_address);
2220
2221 // If the initialized_flag is true, call the `select` function to
2222 // potentially update the selected value and index with the currently
2223 // visiting operand.
2224 llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &b_);
2225 llvm::Value* operand_address =
2226 operand_array.EmitArrayElementAddress(operand_index, &b_);
2227 llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
2228 llvm_ir::PrimitiveTypeToIrType(PRED,
2229 ir_emitter_context_->llvm_module()),
2230 "select_return_buffer", &b_);
2231
2232 TF_ASSIGN_OR_RETURN(
2233 const HloComputation* select_computation,
2234 GetOrCreateSubComputationFromRegion(&select_and_scatter_op.select(),
2235 /*is_fusion=*/false));
2236
2237 TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
2238 *select_computation, {selected_value_address, operand_address},
2239 select_return_buffer));
2240 llvm::Value* result = Load(select_return_buffer);
2241
2242 // If the 'select' function returns false, update the selected value and the
2243 // index to the currently visiting operand.
2244 llvm::Value* cond = ICmpNE(
2245 result,
2246 llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(
2247 PRED, ir_emitter_context_->llvm_module()),
2248 0),
2249 "boolean_predicate");
2250 llvm_ir::LlvmIfData if_select_lhs =
2251 llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
2252 llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
2253 Store(Load(operand_address), selected_value_address);
2254 save_operand_index(operand_index);
2255
2256 // After iterating over the window elements, scatter the source element to
2257 // the selected index of the output. The value we store at the output
2258 // location is computed by calling the `scatter` function with the source
2259 // value and the current output value.
2260 llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
2261 &b_);
2262 std::vector<llvm::Value*> selected_multi_index;
2263 for (int64_t i = 0; i < rank; ++i) {
2264 llvm::Value* selected_index_address_slot =
2265 InBoundsGEP(selected_index_address, {b_.getInt32(i)});
2266 selected_multi_index.push_back(Load(selected_index_address_slot));
2267 }
2268 const Shape output_shape = GetShape(select_and_scatter_op.out());
2269 llvm::Value* source_value_address =
2270 source_array.EmitArrayElementAddress(source_index, &b_);
2271 IrArray::Index selected_index(selected_multi_index, output_shape,
2272 operand_index.GetType());
2273 llvm::Value* output_value_address =
2274 out_array.EmitArrayElementAddress(selected_index, &b_);
2275
2276 TF_ASSIGN_OR_RETURN(
2277 const HloComputation* scatter_computation,
2278 GetOrCreateSubComputationFromRegion(&select_and_scatter_op.scatter(),
2279 /*is_fusion=*/false));
2280
2281 return EmitAtomicOperationForNestedComputation(
2282 *scatter_computation, output_value_address, source_value_address);
2283 };
2284
2285 AddThunkToThunkSequence(std::move(select_and_scatter_thunk));
2286 return ParallelLoopEmitter(loop_body_emitter, source_shape, launch_dimensions,
2287 &b_)
2288 .EmitLoop(name, index_type);
2289 }
2290
2291 Status IrEmitterUnnested::EmitWhile(mlir::Operation* op) {
2292 auto while_op = mlir::cast<mlir::lmhlo::WhileOp>(op);
2293
2294 auto cond_result = GetHloOutputs(while_op);
2295 TF_RET_CHECK(cond_result.size() == 1);
2296 TF_RET_CHECK(cond_result[0]
2297 .getType()
2298 .cast<mlir::ShapedType>()
2299 .getElementType()
2300 .isInteger(/*width=*/1))
2301 << "While condition computation must return bool";
2302
2303 // Build ForThunk for conformant while loops, otherwise build WhileThunk.
2304 if (while_op.trip_count()) {
2305 TF_ASSIGN_OR_RETURN(auto thunk, BuildForThunk(while_op, GetThunkInfo(op),
2306 *while_op.trip_count()));
2307 AddThunkToThunkSequence(std::move(thunk));
2308 } else {
2309 TF_ASSIGN_OR_RETURN(auto thunk,
2310 BuildWhileThunk(while_op, GetThunkInfo(op)));
2311 AddThunkToThunkSequence(std::move(thunk));
2312 }
2313 return Status::OK();
2314 }
2315
2316 Status IrEmitterUnnested::EmitRngGetAndUpdateState(mlir::Operation* op) {
2317 auto rng_op = mlir::dyn_cast<mlir::lmhlo::RngGetAndUpdateStateOp>(op);
2318
2319 // Emit a kernel to increment the global state for Philox RNG algorithm.
2320 std::vector<llvm_ir::IrArray> ir_arrays;
2321 TF_ASSIGN_OR_RETURN(auto kernel_thunk,
2322 BuildKernelThunk(rng_op, rng_op.state(), GetThunkInfo(op),
2323 &ir_arrays, LaunchDimensions()));
2324 AddThunkToThunkSequence(std::move(kernel_thunk));
2325
2326 llvm::Value* old_state =
2327 llvm_ir::RngGetAndUpdateState(rng_op.delta(), module_, &b_);
2328
2329 const Shape shape = GetShape(rng_op.state());
2330
2331 llvm::Value* output_address = ir_arrays[0].EmitArrayElementAddress(
2332 llvm_ir::IrArray::Index(
2333 /*linear=*/b_.getInt64(0), shape, &b_),
2334 &b_, "rng_state_address");
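// The RNG state is produced as one wide integer value, which may be wider
// than the element type of the state buffer, so bitcast the destination
// pointer to the state's type and store the whole value at once.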
2335 output_address = BitCast(
2336 output_address, llvm::PointerType::get(
2337 old_state->getType(),
2338 output_address->getType()->getPointerAddressSpace()));
2339 Store(old_state, output_address);
2340
2341 return Status::OK();
2342 }
2343
2344 Status IrEmitterUnnested::EmitScatter(mlir::Operation* op) {
2345 ThunkSequence thunks;
2346
2347 auto scatter_op = mlir::cast<mlir::lmhlo::ScatterOp>(op);
2348
2349 if (!scatter_op.unique_indices()) {
2350 TF_RETURN_IF_ERROR(
2351 AssertNonDeterminismIsOkay(mlir::GetNameFromLoc(scatter_op.getLoc())));
2352 }
2353
2354 TF_ASSIGN_OR_RETURN(auto operand_buffer,
2355 GetAllocationSlice(scatter_op.operand()));
2356 TF_ASSIGN_OR_RETURN(auto output_buffer,
2357 GetAllocationSlice(scatter_op.output()));
2358
2359 // Copy the operand into the output if it's not the same buffer already.
2360 if (operand_buffer != output_buffer) {
2361 thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
2362 Thunk::ThunkInfo(),
2363 /*source_address=*/operand_buffer,
2364 /*destination_buffer=*/output_buffer,
2365 /*mem_size=*/
2366 ShapeUtil::ByteSizeOf(GetShape(scatter_op.output()))));
2367 }
2368
2369 const Shape& data_shape = GetShape(scatter_op.updates());
2370 TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
2371 CalculateLaunchDimensions(
2372 data_shape, ir_emitter_context_->gpu_device_info()));
2373
2374 // Create kernel thunk for all operands except the first one (`operand`). The
2375 // code generated for scatter below assumes that the input operand is already
2376 // copied into the output, so the operand is not used in codegen.
2377 std::vector<llvm_ir::IrArray> ir_arrays;
2378 thunks.emplace_back();
2379 TF_ASSIGN_OR_RETURN(
2380 thunks.back(),
2381 BuildKernelThunk(scatter_op, scatter_op.getOperands().drop_front(),
2382 GetThunkInfo(op), &ir_arrays, launch_dimensions));
2383
2384 CHECK_EQ(ir_arrays.size(), 3);
2385 const IrArray& scatter_indices = ir_arrays[0];
2386 const IrArray& updates = ir_arrays[1];
2387 const IrArray& output = ir_arrays[2];
2388
2389 auto get_index_type = [&](int64_t launch_size) {
2390 return GetIndexTypeForKernel(scatter_op, launch_size, &b_);
2391 };
2392
2393 TF_RETURN_IF_ERROR(EmitScatter(
2394 thunks.back().get(), scatter_op, launch_dimensions, output,
2395 /*scatter_indices_gen=*/
2396 [&](const IrArray::Index& index) {
2397 return scatter_indices.EmitReadArrayElement(index, &b_,
2398 "scatter_index");
2399 },
2400 /*updates_gen=*/
2401 [&](const IrArray::Index& index) {
2402 return updates.EmitReadArrayElement(index, &b_, "update");
2403 },
2404 /* get_index_type=*/
2405 get_index_type));
2406
2407 // Elide the sequential thunk if there's no copy.
2408 if (thunks.size() == 1) {
2409 AddThunkToThunkSequence(std::move(thunks[0]));
2410 } else {
2411 AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
2412 GetThunkInfo(op), std::move(thunks)));
2413 }
2414
2415 return Status::OK();
2416 }
2417
2418 Status IrEmitterUnnested::EmitScatter(
2419 Thunk* thunk, mlir::lmhlo::ScatterOp scatter,
2420 const LaunchDimensions& launch_dimensions, const llvm_ir::IrArray& output,
2421 const llvm_ir::ElementGenerator& scatter_indices_gen,
2422 const llvm_ir::ElementGenerator& updates_gen,
2423 std::function<llvm::Type*(int64_t)> get_index_type) {
2424 const Shape operand_shape = GetShape(scatter.operand());
2425 CHECK(ShapeUtil::Equal(GetShape(scatter.output()), operand_shape));
2426
2427 TF_ASSIGN_OR_RETURN(
2428 const HloComputation* update_computation,
2429 GetOrCreateSubComputationFromRegion(&scatter.update_computation(),
2430 /*is_fusion=*/false));
2431
2432 ScatterDescriptor desc;
2433 desc.name = mlir::GetNameFromLoc(scatter.getLoc());
2434 desc.operand_shape = operand_shape;
2435 desc.scatter_indices_shape = GetShape(scatter.scatter_indices());
2436 desc.updates_shape = GetShape(scatter.updates());
2437 desc.dim_numbers = scatter.scatter_dimension_numbers();
2438 desc.unique_indices = scatter.unique_indices();
2439 desc.update_computation = update_computation;
2440 desc.output = output;
2441 desc.scatter_indices_gen = scatter_indices_gen;
2442 desc.updates_gen = updates_gen;
2443 desc.get_index_type = get_index_type;
2444 return EmitScatter(desc, thunk, launch_dimensions);
2445 }
2446
2447 Status IrEmitterUnnested::EmitScatter(
2448 const ScatterDescriptor& desc, Thunk* thunk,
2449 const LaunchDimensions& launch_dimensions) {
2450 if (!desc.unique_indices) {
2451 TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay(desc.name));
2452 }
2453 auto loop_body_emitter = [&](const IrArray::Index& index) -> Status {
2454 std::vector<llvm::Value*> raw_window_multidim;
2455 std::vector<llvm::Value*> input_scatter_multidim;
2456 std::vector<int64> raw_window_bounds;
2457
2458 // Partition the index into window indices and scatter indices.
2459 for (int64_t i = 0, e = index.size(); i != e; ++i) {
2460 // For window indices, also remember the window size; this comes in handy
2461 // later.
2462 if (BinarySearchDenseElementsAttr(desc.dim_numbers.update_window_dims(),
2463 i)) {
2464 raw_window_multidim.push_back(index[i]);
2465 raw_window_bounds.push_back(desc.updates_shape.dimensions(i));
2466 } else {
2467 input_scatter_multidim.push_back(index[i]);
2468 }
2469 }
2470 DCHECK_EQ(raw_window_multidim.size(),
2471 desc.dim_numbers.update_window_dims().size());
2472
2473 // Apply inserted_window_dims to the window dimensions.
2474 int64_t raw_window_multidim_idx = 0;
2475 std::vector<llvm::Value*> input_window_multidim;
2476 std::vector<int64> input_window_bounds;
2477
2478 for (int64_t i = 0, e = desc.operand_shape.rank(); i != e; ++i) {
2479 if (BinarySearchDenseElementsAttr(desc.dim_numbers.inserted_window_dims(),
2480 i)) {
2481 input_window_bounds.push_back(1); // Trivial dimension.
2482 input_window_multidim.push_back(index.GetConstantWithIndexType(0));
2483 } else {
2484 input_window_bounds.push_back(
2485 raw_window_bounds[raw_window_multidim_idx]);
2486 input_window_multidim.push_back(
2487 raw_window_multidim[raw_window_multidim_idx]);
2488 ++raw_window_multidim_idx;
2489 }
2490 }
2491 DCHECK_EQ(input_window_multidim.size(), desc.operand_shape.rank());
2492
2493 // Insert a 1 dimension at the end if index_vector_dim requests one.
2494 Shape scatter_indices_shape_fixed = desc.scatter_indices_shape;
2495 if (desc.dim_numbers.index_vector_dim().getInt() ==
2496 desc.scatter_indices_shape.rank()) {
2497 scatter_indices_shape_fixed.add_dimensions(1);
2498 scatter_indices_shape_fixed.mutable_layout()->add_minor_to_major(
2499 desc.dim_numbers.index_vector_dim().getInt());
2500 }
2501
2502 // Now load the indices corresponding to the current window from
2503 // scatter_indices.
2504 std::vector<llvm::Value*> raw_scatter_index_multidim =
2505 input_scatter_multidim;
2506 raw_scatter_index_multidim.insert(
2507 raw_scatter_index_multidim.begin() +
2508 desc.dim_numbers.index_vector_dim().getInt(),
2509 nullptr);
2510 llvm::Value* is_in_bounds = b_.getTrue();
2511 for (int64_t i = 0,
2512 e = desc.dim_numbers.scatter_dims_to_operand_dims().size();
2513 i != e; ++i) {
2514 // Our index is stored along index_vector_dim; insert it into the lookup
2515 // index into scatter_indices.
2516 raw_scatter_index_multidim[desc.dim_numbers.index_vector_dim().getInt()] =
2517 index.GetConstantWithIndexType(i);
2518 llvm_ir::IrArray::Index raw_scatter_index_index(
2519 raw_scatter_index_multidim, scatter_indices_shape_fixed,
2520 index.GetType());
2521
2522 int64_t operand_dim =
2523 desc.dim_numbers.scatter_dims_to_operand_dims().getValue<int64>(i);
2524 TF_ASSIGN_OR_RETURN(
2525 llvm::Value* const loaded_scatter_index,
2526 desc.scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape(
2527 scatter_indices_shape_fixed, desc.scatter_indices_shape, &b_)));
2528 // And add the index to our window index. This yields the output index.
2529 llvm::Value* casted_scatter_index =
2530 IntCast(loaded_scatter_index, index.GetType(),
2531 /*isSigned=*/true);
2532 llvm::Value* dim_offset =
2533 Add(input_window_multidim[operand_dim], casted_scatter_index);
2534 input_window_multidim[operand_dim] = dim_offset;
2535
2536 // Also do the bounds check now.
2537 int64_t max_index = desc.operand_shape.dimensions(operand_dim) -
2538 input_window_bounds[operand_dim] + 1;
2539 // is_in_bounds = index >= 0 && index < dim_size-window_size+1
2540 // --> index u< dim_size-window_size+1
2541 is_in_bounds =
2542 And(is_in_bounds, ICmpULT(casted_scatter_index,
2543 index.GetConstantWithIndexType(max_index)));
2544 }
2545
2546 llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse(
2547 is_in_bounds, "scatter.in_bounds", &b_, /*emit_else=*/false);
2548 llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_);
2549 // All done; now read the input value at the calculated window position and
2550 // do an atomic store to the calculated location in the output.
2551 llvm_ir::IrArray::Index input_window_index(
2552 input_window_multidim, desc.output.GetShape(), index.GetType());
2553 llvm::Value* output_address =
2554 desc.output.EmitArrayElementAddress(input_window_index, &b_);
2555 llvm::Value* input_address = llvm_ir::EmitAllocaAtFunctionEntry(
2556 llvm_ir::PrimitiveTypeToIrType(desc.updates_shape.element_type(),
2557 module_),
2558 "input_address", &b_);
2559 TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
2560 desc.updates_gen(index));
2561 Store(input_ir_value, input_address);
2562
2563 if (!desc.unique_indices) {
2564 return EmitAtomicOperationForNestedComputation(
2565 *desc.update_computation, output_address, input_address);
2566 } else {
2567 return EmitCallToNestedComputation(*desc.update_computation,
2568 {output_address, input_address},
2569 output_address);
2570 }
2571 };
2572
2573 // Launch a kernel that reads every element in the updates tensor. We could
2574 // also do one kernel per window instead if bounds checks turn out to be a
2575 // bottleneck.
2576 return ParallelLoopEmitter(loop_body_emitter, desc.updates_shape,
2577 launch_dimensions, &b_)
2578 .EmitLoop(desc.name,
2579 desc.get_index_type(launch_dimensions.launch_bound()));
2580 }
2581
2582 // This transformation should be migrated off. See b/171334474.
2583 StatusOr<HloComputation*>
2584 IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region,
2585 bool is_fusion) {
2586 std::unique_ptr<HloModule>& module = scratch_nested_computations_[region];
2587 if (module == nullptr) {
2588 std::vector<Shape> operand_shapes, output_shapes;
2589 if (is_fusion) {
2590 mlir::Operation* clone = region->getParentOp()->clone();
2591 region = &mlir::cast<mlir::lmhlo::FusionOp>(clone).region();
2592 TF_RETURN_IF_ERROR(
2593 ProcessFusionForConversion(region, &operand_shapes, &output_shapes));
2594 }
2595
2596 xla::XlaComputation xla_computation;
2597 mlir::MlirToHloConversionOptions options;
2598 options.propagate_layouts = true;
2599 options.propagate_bitcast_layouts_to_backend_config = true;
2600 TF_RETURN_IF_ERROR(
2601 ConvertRegionToComputation(region, &xla_computation, options));
2602
2603 if (is_fusion) {
2604 region->getParentOp()->erase();
2605 }
2606
2607 TF_ASSIGN_OR_RETURN(auto program_shape, xla_computation.GetProgramShape());
2608 TF_ASSIGN_OR_RETURN(
2609 module, HloModule::CreateFromProto(xla_computation.proto(),
2610 HloModuleConfig(program_shape)));
2611
2612 if (is_fusion) {
2613 HloComputation* fused_computation = module->entry_computation();
2614
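// The parameter and root shapes of the imported computation may not carry the
// buffer layouts, so copy the layouts recorded in ProcessFusionForConversion
// back onto them (parameters here, root below).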
2615 CHECK_EQ(operand_shapes.size(), fused_computation->num_parameters());
2616 for (int i = 0; i < fused_computation->num_parameters(); i++) {
2617 *fused_computation->parameter_instruction(i)
2618 ->mutable_shape()
2619 ->mutable_layout() = operand_shapes[i].layout();
2620 }
2621 HloInstruction* root = fused_computation->root_instruction();
2622 // Manually fold Tuple(GTE(a, 0), GTE(a, 1), GTE(a, 2), ...) to a.
2623 // FusedIrEmitter doesn't take GTE ops because we aim to eliminate tuples
2624 // as much as possible.
2625 if (root->opcode() == HloOpcode::kTuple) {
2626 [&] {
2627 HloInstruction* real_root = nullptr;
2628 int expected_tuple_index = 0;
2629 for (HloInstruction* operand : root->operands()) {
2630 if (operand->opcode() != HloOpcode::kGetTupleElement) {
2631 return;
2632 }
2633 if (real_root == nullptr) {
2634 real_root = operand->mutable_operand(0);
2635 } else if (real_root != operand->operand(0)) {
2636 return;
2637 }
2638 if (expected_tuple_index != operand->tuple_index()) {
2639 return;
2640 }
2641 expected_tuple_index++;
2642 }
2643 fused_computation->set_root_instruction(real_root);
2644 std::vector<HloInstruction*> to_be_removed;
2645 to_be_removed.push_back(root);
2646 for (HloInstruction* operand : root->operands()) {
2647 to_be_removed.push_back(operand);
2648 }
2649 for (auto instr : to_be_removed) {
2650 TF_CHECK_OK(fused_computation->RemoveInstruction(instr));
2651 }
2652
2653 root = real_root;
2654 }();
2655 }
2656
2657 if (output_shapes.size() > 1) {
2658 CHECK(root->shape().IsTuple());
2659 CHECK_EQ(root->shape().tuple_shapes_size(), output_shapes.size());
2660
2661 for (int i = 0; i < output_shapes.size(); i++) {
2662 *root->mutable_shape()->mutable_tuple_shapes(i) = output_shapes.at(i);
2663 }
2664 } else {
2665 CHECK_EQ(1, output_shapes.size());
2666 *root->mutable_shape() = output_shapes[0];
2667 }
2668 }
2669 // Post-process the generated computation:
2670 // * Sanitize constant names, so that they can be used as LLVM global
2671 // symbols.
2672 // * Propagate layouts for tuple types.
2673 for (HloComputation* computation : module->computations()) {
2674 for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
2675 if (instr->opcode() == HloOpcode::kConstant) {
2676 // Notice that IR emitters use the name of constants as LLVM symbol
2677 // names, therefore it's important to not let these constants in the
2678 // new module collide with constants in the original module by name.
2679 // Unique them by prepending the module name.
2680 //
2681 // TODO(timshen): A better solution would be to plumb the exact
2682 // constant names through original HLO -> LHLO -> MHLO -> HLO. This is
2683 // hard because XLA builder doesn't support setting names. Revisit
2684 // this once we get rid of this function, or don't rely on the op name
2685 // (which shouldn't be the identity) to generate LLVM symbols.
2686 instr->SetAndSanitizeName(llvm_ir::SanitizeConstantName(
2687 module->name() + "_" + instr->name()));
2688 }
2689 if (instr->shape().IsTuple() &&
2690 computation == module->entry_computation() &&
2691 instr != computation->root_instruction()) {
2692 return InternalError("Non-root tuple types are not handled.");
2693 }
2694 }
2695 }
2696 }
2697 return module->entry_computation();
2698 }
2699
2700 Status IrEmitterUnnested::EmitSort(mlir::Operation* op) {
2701 auto sort_op = mlir::cast<mlir::lmhlo::SortOp>(op);
2702 MlirEmitterContext context;
2703 context.SetOperation(sort_op);
2704
2705 ThunkSequence thunks;
2706
2707 const Shape& keys_shape = context.operand_shapes[0];
2708 int64_t dimension_to_sort = sort_op.dimension();
2709 for (int64_t i = 0; i < context.operand_shapes.size(); ++i) {
2710 // We assume that the layout of all involved operands and outputs is the
2711 // same.
2712 TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(keys_shape,
2713 context.operand_shapes[i]));
2714 TF_RET_CHECK(
2715 LayoutUtil::LayoutsInShapesEqual(keys_shape, context.output_shapes[i]));
2716
2717 // If possible, we share buffers. If that is not possible, we need to copy
2718 // the values, because the emitter does the sorting in-place.
2719 TF_ASSIGN_OR_RETURN(auto destination_buffer,
2720 GetAllocationSlice(sort_op.output()[i]));
2721 TF_ASSIGN_OR_RETURN(auto source_address,
2722 GetAllocationSlice(sort_op.operands()[i]));
2723 if (destination_buffer != source_address) {
2724 // TODO(b/26783907): Figure out why we never seem to share buffers for
2725 // key/value sort.
2726 VLOG(2) << context.name << " requires initial D2D copy for operand " << i;
2727 thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
2728 Thunk::ThunkInfo(),
2729 /*source_address=*/source_address,
2730 /*destination_buffer=*/destination_buffer,
2731 /*mem_size=*/ShapeUtil::ByteSizeOf(context.operand_shapes[i])));
2732 }
2733 }
2734
2735 uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
2736 int64_t num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound);
2737 VLOG(2) << context.name << " requires " << num_stages << " stages.";
2738 CHECK_GE(1ULL << num_stages, dimension_to_sort_bound);
2739 CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound);
2740
2741 // Naive C++ code for the outer loops:
2742 //
2743 // for (int64 stage = 0; stage < Log2Ceiling(dimension_to_sort_bound);
2744 // ++stage) {
2745 // int64 first_xor_mask = (1LL << (stage + 1)) - 1;
2746 // SortInPlace(first_xor_mask);
2747 // for (int64 mask = stage - 1; mask >= 0; --mask) {
2748 // int64 later_xor_mask = 1LL << mask;
2749 // SortInPlace(later_xor_mask);
2750 // }
2751 // }
2752 //
2753 // This follows the alternative representation of the algorithm described on
2754 // Wikipedia: https://en.wikipedia.org/wiki/Bitonic_sorter
2755 //
2756 // Each mask specifies how to derive from one position in the array the
2757 // position with which it should be compared (we calculate the xor of the
2758 // position with the mask).
2759 // As an optimization, we can move the 'mask' loop to inside the
2760 // sorting/comparison loop if the comparisons happen within a small block of
2761 // the array. To make this work, we collect all consecutive masks that are
2762 // smaller than our chosen power of 2 tile size, and pass them to SortInPlace.
2763 // Each thread then processes one tile of data.
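  // To make the pairing concrete: with xor mask 0b100, the element at
  // position 5 (0b101) is compared with the element at position 1 (0b001).
  //
  // A rough sketch of the tiled variant of the outer loops above (simplified:
  // it ignores the case where tiling is disabled entirely):
  //
  //   std::vector<int64> pending_masks;
  //   for (int64 stage = 0; stage < num_stages; ++stage) {
  //     for (int64 mask = stage; mask >= 0; --mask) {
  //       int64 xor_mask =
  //           mask == stage ? (1LL << (stage + 1)) - 1 : 1LL << mask;
  //       if (xor_mask >= kTileSize) {
  //         // Comparisons cross tile boundaries: flush the pending masks and
  //         // emit a standalone (untiled) kernel for this mask.
  //         if (!pending_masks.empty()) SortInPlace(pending_masks);
  //         pending_masks.clear();
  //         SortInPlace({xor_mask});
  //       } else {
  //         pending_masks.push_back(xor_mask);  // Combined into one kernel.
  //       }
  //     }
  //   }
  //   if (!pending_masks.empty()) SortInPlace(pending_masks);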
2764
2765 const uint64 kTileSize = std::min(2048ULL, 1ULL << num_stages);
2766
2767 // If we cannot combine several xor masks together, we don't use tiling, so we
2768 // calculate the standard launch dimensions for the shape. However we only
2769 // need to iterate through ~half of the dimension to sort (rounded up to the
2770 // next highest power of 2), because each iteration compares one pair of
2771 // elements.
2772 Shape standard_iteration_shape = keys_shape;
2773 uint64 standard_num_iterations_in_sort_dim = 1ULL << (num_stages - 1);
2774 standard_iteration_shape.set_dimensions(dimension_to_sort,
2775 standard_num_iterations_in_sort_dim);
2776 TF_ASSIGN_OR_RETURN(
2777 LaunchDimensions standard_launch_dimensions,
2778 CalculateLaunchDimensions(standard_iteration_shape,
2779 ir_emitter_context_->gpu_device_info()));
2780
2781 // Calculate the launch dimensions for the case where we use tiling. We split
2782 // the dimension that should be sorted into tiles of size 'kTileSize'. This
2783 // means we first need to round 'dimension_to_sort_bound' up to be a multiple
2784 // of the tile size.
2785 int64_t rounded_bound = RoundUpToNearest(dimension_to_sort_bound, kTileSize);
2786 Shape iteration_shape = keys_shape;
2787
2788 // We iterate through the element pairs that should be compared.
2789 uint64 num_iterations_in_sort_dim = rounded_bound / 2;
2790 iteration_shape.set_dimensions(dimension_to_sort, num_iterations_in_sort_dim);
2791 uint64 num_iterations = ShapeUtil::ElementsIn(iteration_shape);
2792
2793   // For correctness reasons we need exactly 'kTileSize' / 2 threads per
2794 // block. Each thread is responsible for copying exactly two adjacent elements
2795 // into shared memory, and then does a comparison of two possibly different
2796 // elements taken from shared memory.
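  // For example, when kTileSize == 2048 (its upper bound above), this means
  // 1024 threads per block, each owning one adjacent pair of elements of its
  // 2048-element tile (the exact pairing is decided inside
  // llvm_ir::EmitSortInPlace, so this is only an illustration).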
2797 const uint64 kThreadsPerBlock = kTileSize / 2;
2798
2799   // Check whether we should use any tiling. We might not be able to use it if
2800   // we don't have enough threads or enough shared memory. Tiling also gives no
2801   // speedup if the tile size is < 128.
2802 int64_t total_shared_memory_needed = 0;
2803 for (int64_t i = 0; i < context.operand_shapes.size(); ++i) {
2804 total_shared_memory_needed +=
2805 kTileSize * ShapeUtil::ByteSizeOfPrimitiveType(
2806 context.operand_shapes[i].element_type());
2807 }
2808 bool no_tiling =
2809 kTileSize < 128 ||
2810 kThreadsPerBlock >
2811 ir_emitter_context_->gpu_device_info().threads_per_block_limit ||
2812 total_shared_memory_needed >
2813 ir_emitter_context_->gpu_device_info().shared_memory_per_block;
2814 VLOG(2) << absl::StreamFormat(
2815 "%s %s use tiling. No tiling if any of the following is true: "
2816 "kTileSize=%d < 128, "
2817 "kThreadsPerBlock=%d > threads_per_block_limit=%d, "
2818 "total_shared_memory_needed=%d > shared_memory_per_block=%d",
2819 context.name, (no_tiling ? "won't" : "will"), kTileSize, kThreadsPerBlock,
2820 ir_emitter_context_->gpu_device_info().threads_per_block_limit,
2821 total_shared_memory_needed,
2822 ir_emitter_context_->gpu_device_info().shared_memory_per_block);
2823
2824 uint64 num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock);
2825 LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock);
2826 VLOG(2) << absl::StreamFormat("%s launch dims: %d blocks, %d threads/block",
2827 context.name, num_blocks, kThreadsPerBlock);
2828
2829 std::vector<llvm_ir::IrArray> ir_arrays;
2830 auto emit_kernel = [&](absl::Span<const int64> xor_masks) {
2831 VLOG(2) << absl::StreamFormat(
2832 "%s uses kernel for xor masks [%s]", context.name,
2833 absl::StrJoin(xor_masks, ", ", [](std::string* out, int64_t xor_mask) {
2834 absl::StrAppendFormat(out, "0x%x", xor_mask);
2835 }));
2836 thunks.emplace_back();
2837 LaunchDimensions launch_dimensions = xor_masks.size() > 1
2838 ? tiled_launch_dimensions
2839 : standard_launch_dimensions;
2840 TF_ASSIGN_OR_RETURN(
2841 thunks.back(),
2842 BuildKernelThunk(sort_op, sort_op.output(), Thunk::ThunkInfo(),
2843 &ir_arrays, launch_dimensions));
2844 std::vector<IrArray> values_arrays;
2845 values_arrays.reserve(context.operand_shapes.size());
2846 for (int64_t i = 0; i < context.operand_shapes.size(); ++i) {
2847 values_arrays.push_back(ir_arrays[i]);
2848 }
2849 TF_ASSIGN_OR_RETURN(const HloComputation* comparator,
2850 GetOrCreateSubComputationFromRegion(
2851 &sort_op.comparator(), /*is_fusion=*/false));
2852 return llvm_ir::EmitSortInPlace(
2853 dimension_to_sort, values_arrays, IrName(context.name), xor_masks, &b_,
2854 launch_dimensions,
2855 xor_masks.size() > 1 ? num_iterations_in_sort_dim
2856 : standard_num_iterations_in_sort_dim,
2857 kTileSize,
2858 [&](absl::Span<llvm::Value* const> operands, llvm::Value* output) {
2859 return EmitCallToNestedComputation(*comparator, operands, output);
2860 });
2861 };
2862 std::vector<int64> xor_masks;
2863 for (int64_t stage = 0; stage < num_stages; ++stage) {
2864 for (int64_t mask = stage; mask >= 0; --mask) {
2865 int64_t xor_mask;
2866 if (mask == stage) {
2867 xor_mask = (1LL << (stage + 1)) - 1;
2868 } else {
2869 xor_mask = 1LL << mask;
2870 }
2871 if (xor_mask >= kTileSize || no_tiling) {
2872 if (!xor_masks.empty()) {
2873 TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
2874 xor_masks.clear();
2875 }
2876 TF_RETURN_IF_ERROR(emit_kernel({xor_mask}));
2877 } else {
2878 xor_masks.push_back(xor_mask);
2879 }
2880 }
2881 }
2882 if (!xor_masks.empty()) {
2883 TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
2884 }
2885 VLOG(2) << absl::StreamFormat(
2886 "%s requires %d thunks (including any D2D copies)", context.name,
2887 thunks.size());
2888
2889 AddThunkToThunkSequence(
2890 absl::make_unique<SequentialThunk>(GetThunkInfo(op), std::move(thunks)));
2891 return Status::OK();
2892 }
2893
2894 template <typename ThunkType, typename OpT>
2895 Status IrEmitterUnnested::EmitReplicaOrPartitionId(mlir::Operation* op) {
2896 auto casted = mlir::cast<OpT>(op);
2897 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
2898 GetAllocationSlice(casted.getOperand()));
2899 AddThunkToThunkSequence(
2900 absl::make_unique<ThunkType>(GetThunkInfo(op), result_slice));
2901 return Status::OK();
2902 }
2903
2904 Status IrEmitterUnnested::EmitCollectivePermute(mlir::Operation* op) {
2905 auto collective_permute_op = mlir::cast<mlir::lmhlo::CollectivePermuteOp>(op);
2906
2907 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice source_slice,
2908 GetAllocationSlice(collective_permute_op.operand()));
2909 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
2910 GetAllocationSlice(collective_permute_op.output()));
2911
2912 const Shape shape = GetShape(collective_permute_op.operand());
2913 const int64_t replica_count = hlo_module_config_.replica_count();
2914 const int64_t partition_count = hlo_module_config_.num_partitions();
2915
2916 if (NcclCollectivePermuteThunk::IsDegenerate(
2917 collective_permute_op, replica_count, partition_count)) {
2918 // For a degenerate collective permute, just generate a copy thunk.
2919 AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
2920 GetThunkInfo(op),
2921 /*source_address=*/source_slice,
2922 /*destination_buffer=*/result_slice,
2923 /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
2924 } else {
2925 const NcclCollectivePermuteThunk::Buffer buffer = {
2926 /*element_count=*/ShapeUtil::ElementsIn(shape),
2927 /*source_buffer=*/source_slice,
2928 /*destination_buffer=*/result_slice};
2929 AddThunkToThunkSequence(absl::make_unique<NcclCollectivePermuteThunk>(
2930 GetThunkInfo(op), collective_permute_op, replica_count, partition_count,
2931 buffer));
2932 }
2933 return Status::OK();
2934 }
2935
2936 template <typename NcclThunkType>
2937 Status MaybeAddAllReduceStartThunkToMap(
2938 absl::flat_hash_map<mlir::Operation*, NcclAllReduceStartThunk*>&,
2939 mlir::Operation* op, NcclThunkType* thunk) {
2940 return Status::OK();
2941 }
2942
2943 template <>
2944 Status MaybeAddAllReduceStartThunkToMap(
2945 absl::flat_hash_map<mlir::Operation*, NcclAllReduceStartThunk*>&
2946 all_reduce_start_thunks,
2947 mlir::Operation* op, NcclAllReduceStartThunk* thunk) {
2948 TF_RET_CHECK(all_reduce_start_thunks.emplace(op, thunk).second)
2949 << "all-reduce-start with this unique ID already seen";
2950 return Status::OK();
2951 }
2952
2953 template <typename NcclThunkType, typename OpTy>
2954 Status IrEmitterUnnested::EmitNcclThunk(mlir::Operation* untyped_op) {
2955 OpTy op = mlir::cast<OpTy>(untyped_op);
2956 int64_t replica_count = hlo_module_config_.replica_count();
2957 int64_t partition_count = hlo_module_config_.num_partitions();
2958 VLOG(2) << NcclThunkType::GetName() << "; replica count: " << replica_count
2959 << "; partition count: " << partition_count
2960 << "; operand count: " << op.operands().size()
2961 << "; NCCL is enabled: " << NcclThunkType::NcclIsEnabled();
2962
2963   // A given collective op is degenerate if all of the groups it forms are
2964   // singletons. In such a case, we don't need to do any communication and we
2965   // can just copy the input to the output.
2966 bool is_degenerate =
2967 NcclThunkType::IsDegenerate(op, replica_count, partition_count);
2968 bool should_use_nccl_thunk =
2969 !is_degenerate && NcclThunkType::CanImplement(op);
2970
2971 // Stash relevant information in NcclCollectiveThunk::Buffer even if we may
2972 // not generate an NcclCollectiveThunk.
2973 std::vector<NcclCollectiveThunk::Buffer> buffers;
2974 buffers.reserve(op.operands().size());
2975 for (auto it : llvm::zip(op.operands(), op.results())) {
2976 mlir::Value operand = std::get<0>(it);
2977 mlir::Value result = std::get<1>(it);
2978 const Shape shape = GetShape(operand);
2979 TF_ASSIGN_OR_RETURN(auto source_slice, GetAllocationSlice(operand));
2980 TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSlice(result));
2981 buffers.push_back(NcclCollectiveThunk::Buffer{
2982 /*element_count=*/ShapeUtil::ElementsIn(shape),
2983 /*source_buffer=*/source_slice,
2984 /*destination_buffer=*/dest_slice});
2985 }
2986
2987 if (should_use_nccl_thunk) {
2988 auto nccl_thunk =
2989 absl::make_unique<NcclThunkType>(GetThunkInfo(op), op,
2990 /*buffers=*/std::move(buffers));
2991 // Record thunks for all-reduce-start ops as the done ops need them.
2992 TF_RETURN_IF_ERROR(MaybeAddAllReduceStartThunkToMap(
2993 all_reduce_start_thunks_, op, nccl_thunk.get()));
2994 AddThunkToThunkSequence(std::move(nccl_thunk));
2995 return Status::OK();
2996 }
2997
2998   // Record a nullptr to signal that no all-reduce-start thunk was created.
2999 TF_RETURN_IF_ERROR(MaybeAddAllReduceStartThunkToMap(
3000 all_reduce_start_thunks_, op, static_cast<NcclThunkType*>(nullptr)));
3001
3002 if (!is_degenerate) {
3003 CollectiveOpGroupMode group_mode = NcclThunkType::GetGroupMode(op);
3004
3005 string message = absl::StrFormat(
3006 "Requested %s not implemented on GPU; replica_count: %d; "
3007 "partition_count: %d, group_mode: %s, operand_count: %d; NCCL support: "
3008 "%d",
3009 NcclThunkType::GetName(), replica_count, partition_count,
3010 CollectiveOpGroupModeToString(group_mode), op.operands().size(),
3011 NcclThunkType::NcclIsEnabled());
3012 if (!op.operands().empty()) {
3013 const Shape shape = GetShape(op.operands().front());
3014 absl::StrAppendFormat(&message, "; first operand array element-type: %s",
3015 PrimitiveType_Name(shape.element_type()));
3016 }
3017 return Unimplemented("%s", message);
3018 }
3019
3020   // A degenerate collective is simply the identity function: the output equals
3021   // the input. Buffer assignment still expects a copy, so that's what we emit.
3022 ThunkSequence thunks;
3023 for (int64_t i = 0; i < buffers.size(); i++) {
3024 const Shape shape = GetShape(op.operands()[i]);
3025 thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
3026 buffers.size() == 1 ? GetThunkInfo(op) : Thunk::ThunkInfo(),
3027 /*source_address=*/buffers[i].source_buffer,
3028 /*destination_buffer=*/buffers[i].destination_buffer,
3029 /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
3030 }
3031 if (thunks.size() == 1) {
3032 AddThunkToThunkSequence(std::move(thunks[0]));
3033 } else {
3034 AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
3035 GetThunkInfo(op), std::move(thunks)));
3036 }
3037 return Status::OK();
3038 }
3039
3040 Status IrEmitterUnnested::EmitAllReduceDone(mlir::Operation* op) {
3041 auto done_op = mlir::cast<mlir::lmhlo_gpu::AllReduceDoneOp>(op);
3042 auto start_op =
3043 done_op.token().getDefiningOp<mlir::lmhlo_gpu::AllReduceStartOp>();
3044 auto it = all_reduce_start_thunks_.find(start_op);
3045 TF_RET_CHECK(it != all_reduce_start_thunks_.end())
3046 << "couldn't find thunk for all-reduce-start op";
3047
3048 // Can be null if no all-reduce-start thunk was created (e.g. if the start op
3049 // is degenerate), in which case there's nothing to do here.
3050 if (it->second != nullptr) {
3051 AddThunkToThunkSequence(absl::make_unique<NcclAllReduceDoneThunk>(
3052 GetThunkInfo(op), *it->second));
3053 }
3054 all_reduce_start_thunks_.erase(it);
3055 return Status::OK();
3056 }
3057
3058 Status IrEmitterUnnested::EmitInfeed(mlir::Operation* op) {
3059 auto infeed_op = mlir::cast<mlir::lmhlo::InfeedOp>(op);
3060
3061 std::vector<ShapedSlice> dest_slices;
3062 dest_slices.reserve(infeed_op.outputs().size());
3063
3064 for (mlir::Value output : infeed_op.outputs()) {
3065 TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(output));
3066 const Shape& shape = GetShape(output);
3067 dest_slices.push_back(ShapedSlice{slice, shape});
3068 }
3069
3070 AddThunkToThunkSequence(
3071 absl::make_unique<InfeedThunk>(GetThunkInfo(op), std::move(dest_slices)));
3072 return Status::OK();
3073 }
3074
3075 Status IrEmitterUnnested::EmitOutfeed(mlir::Operation* op) {
3076 auto outfeed_op = mlir::cast<mlir::lmhlo::OutfeedOp>(op);
3077
3078 std::vector<ShapedSlice> source_slices;
3079 source_slices.reserve(outfeed_op.operands().size());
3080
3081 for (mlir::Value operand : outfeed_op.operands()) {
3082 TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSlice(operand));
3083 const Shape& shape = GetShape(operand);
3084 source_slices.push_back(ShapedSlice{slice, shape});
3085 }
3086
3087 AddThunkToThunkSequence(absl::make_unique<OutfeedThunk>(
3088 GetThunkInfo(op), std::move(source_slices)));
3089 return Status::OK();
3090 }
3091
3092 std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunkImpl(
3093 absl::string_view name, Thunk::ThunkInfo thunk_info,
3094 absl::Span<const BufferSlice> slices,
3095 std::vector<llvm_ir::IrArray>* ir_arrays,
3096 const LaunchDimensions& launch_dimensions) {
3097 // Figure out which buffer allocations need to be passed as arguments to our
3098 // kernel. This is simply all of the allocations referenced in slices,
3099 // plus the XLA temp buffer (if we have it). We always include the temp
3100 // buffer because even if the kernel itself doesn't use it, a nested
3101 // subcomputation within the kernel (e.g. a kMap's computation) might.
3102 std::unordered_set<const BufferAllocation*> buffers_needed;
3103 for (const auto& slice : slices) {
3104 buffers_needed.insert(slice.buffer_slice.allocation());
3105 }
3106 absl::optional<const BufferAllocation*> temp_buffer;
3107 for (const BufferAllocation& alloc : ir_emitter_context_->allocations()) {
3108 if (alloc.IsPreallocatedTempBuffer()) {
3109 if (!temp_buffer.has_value()) {
3110 // Retrieve the first seen temp buffer.
3111 temp_buffer = &alloc;
3112 }
3113 }
3114 }
3115 if (temp_buffer.has_value()) {
3116 buffers_needed.insert(*temp_buffer);
3117 }
3118
3119   // We'll pass a pointer to each of the elements of `non_constant_buffers` to
3120   // our kernel, in this order.
3121 std::vector<const BufferAllocation*> non_constant_buffers;
3122 absl::c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers),
3123 [](const BufferAllocation* allocation) {
3124 return !allocation->is_constant();
3125 });
3126
3127 absl::c_sort(non_constant_buffers,
3128 [](const BufferAllocation* a, const BufferAllocation* b) {
3129 return a->index() < b->index();
3130 });
3131
3132 llvm::Function* kernel = BuildKernelPrototype(name, non_constant_buffers);
3133
3134 // Build a map from a BufferAllocation to the corresponding argument in our
3135 // kernel.
3136 std::unordered_map<const BufferAllocation*, llvm::Value*> kernel_args;
3137 {
3138 auto arg_it = kernel->arg_begin();
3139 auto buffers_it = non_constant_buffers.begin();
3140 for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) {
3141 kernel_args[*buffers_it] = arg_it;
3142
3143 // Annotate all allocations with LLVM's `noalias`.
3144 // There are three kinds of allocations:
3145 // * Read-only allocations, aka input parameters that are not aliased with
3146 // outputs.
3147 // * Read-write allocations, including all output buffers, some of which
3148 // may alias with input HLO parameters, but aliased HLO buffers are always
3149 // assigned with the same allocation.
3150 // * The temp buffer.
3151 //
3152 // Read-only allocations may overlap with each other, but since they are
3153 // not mutated, they can always be annotated with `noalias` per LLVM
3154 // semantics.
3155 //
3156       // Read-write allocations and the temp buffer don't overlap with any other
3157       // allocations, therefore they can also be annotated with `noalias`.
3158 kernel->addParamAttr(
3159 arg_it->getArgNo(),
3160 llvm::Attribute::get(arg_it->getContext(), llvm::Attribute::NoAlias));
3161 }
3162 }
3163
3164 absl::flat_hash_set<BufferAllocation::Slice> buffers_written;
3165 for (const auto& slice : slices) {
3166 if (slice.written) {
3167 buffers_written.insert(slice.buffer_slice);
3168 }
3169 }
3170
3171 ir_arrays->clear();
3172
3173 // For each buffer our kernel might want to touch, bind it to a value derived
3174 // from our kernel args.
3175 for (const auto& slice : slices) {
3176 const BufferAllocation::Slice& buffer_slice = slice.buffer_slice;
3177
3178 llvm::Value* loc;
3179 if (!slice.constant_name.empty()) {
3180 loc = ir_emitter_context_->llvm_module()->getGlobalVariable(
3181 slice.constant_name);
3182 CHECK_NE(loc, nullptr);
3183 } else {
3184 CHECK(!buffer_slice.allocation()->is_constant());
3185 loc = InBoundsGEP(kernel_args.at(buffer_slice.allocation()),
3186 {b_.getInt64(buffer_slice.offset())});
3187 }
3188
3189 llvm_ir::IrArray ir_array(CastToTypedValue(slice.shape, loc, &b_),
3190 slice.shape);
3191 if (!buffers_written.contains(slice.buffer_slice)) {
3192 ir_array.MarkInvariantOverWholeProgram(&loc->getContext());
3193 }
3194
3195 ir_arrays->push_back(ir_array);
3196 }
3197
3198 AnnotateThunkLaunchDimensions(launch_dimensions,
3199 std::string(kernel->getName()),
3200 ir_emitter_context_->llvm_module());
3201 return absl::make_unique<KernelThunk>(thunk_info, non_constant_buffers,
3202 std::string(kernel->getName()),
3203 launch_dimensions);
3204 }
3205
3206 StatusOr<std::unique_ptr<KernelThunk>> IrEmitterUnnested::BuildKernelThunk(
3207 mlir::Operation* op, mlir::ValueRange operands, Thunk::ThunkInfo thunk_info,
3208 std::vector<llvm_ir::IrArray>* ir_arrays,
3209 const LaunchDimensions& launch_dimensions) {
3210 TF_RET_CHECK(!mlir::isa<mlir::lmhlo::FusionOp>(op));
3211
3212 std::vector<BufferSlice> slices;
3213 for (mlir::Value operand : operands) {
3214 slices.emplace_back();
3215 auto& slice = slices.back();
3216 TF_ASSIGN_OR_RETURN(slice.buffer_slice,
3217 GetAllocationSlice(operand, &slice.constant_name));
3218 slice.written = WritesMlirBuffer(op, operand);
3219 slice.shape = GetShape(operand);
3220 }
3221 std::string name = mlir::GetNameFromLoc(op->getLoc());
3222 return BuildKernelThunkImpl(name, thunk_info, slices, ir_arrays,
3223 launch_dimensions);
3224 }
3225
3226 StatusOr<std::unique_ptr<KernelThunk>> IrEmitterUnnested::BuildKernelThunk(
3227 mlir::Operation* op, Thunk::ThunkInfo thunk_info,
3228 std::vector<llvm_ir::IrArray>* ir_arrays,
3229 const LaunchDimensions& launch_dimensions) {
3230 if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
3231 auto operands = GetHloOperands(op);
3232 auto outputs = GetHloOutputs(op);
3233
3234 std::vector<BufferSlice> slices;
3235 for (auto operand : operands) {
3236 slices.emplace_back();
3237 auto& slice = slices.back();
3238 TF_ASSIGN_OR_RETURN(slice.buffer_slice,
3239 GetAllocationSlice(operand, &slice.constant_name));
3240 slice.written = false;
3241 slice.shape = GetShape(operand);
3242 }
3243 for (auto output : outputs) {
3244 slices.emplace_back();
3245 auto& slice = slices.back();
3246 TF_ASSIGN_OR_RETURN(slice.buffer_slice,
3247 GetAllocationSlice(output, &slice.constant_name));
3248 slice.written = true;
3249 slice.shape = GetShape(output);
3250 }
3251 std::string name = mlir::GetNameFromLoc(op->getLoc());
3252 return BuildKernelThunkImpl(name, thunk_info, slices, ir_arrays,
3253 launch_dimensions);
3254 }
3255 return BuildKernelThunk(op, op->getOperands(), thunk_info, ir_arrays,
3256 launch_dimensions);
3257 }
3258
3259 std::unique_ptr<Thunk> IrEmitterUnnested::BuildConstantInitializerThunk(
3260 absl::Span<const uint8> init_value, const BufferAllocation::Slice& dest,
3261 const Shape& output_shape) {
3262 int64_t num_bytes = init_value.size();
3263 if (absl::c_all_of(init_value, [](uint8 byte) { return byte == 0; })) {
3264 return absl::make_unique<MemzeroThunk>(Thunk::ThunkInfo(), dest);
3265 }
3266
3267 // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
3268 // repeating the literal 4 or 2 times, so long as the destination buffer is
3269 // an even multiple of 32 bits long.
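  // For example (illustrative values): a one-byte literal 0xAB yields
  // pattern16 == 0xABAB and pattern32 == 0xABABABAB; a two-byte literal with
  // bytes {0x34, 0x12} yields, on a little-endian host, pattern16 == 0x1234
  // and pattern32 == 0x12341234.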
3270 if ((num_bytes == 1 || num_bytes == 2) &&
3271 ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) {
3272 uint16 pattern16;
3273 if (num_bytes == 1) {
3274 uint8 b = init_value.front();
3275 pattern16 = uint16{b} | (uint16{b} << 8);
3276 } else {
3277 memcpy(&pattern16, init_value.data(), sizeof(pattern16));
3278 }
3279 uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
3280 return absl::make_unique<Memset32BitValueThunk>(Thunk::ThunkInfo(),
3281 pattern32, dest);
3282 }
3283
3284 // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
3285 // memset so long as all 32-bit words of the scalar are equal to each other.
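  // The overlapping memcmp below checks 4-byte periodicity: comparing bytes
  // [0, n-4) against bytes [4, n) succeeds exactly when every byte equals the
  // byte four positions before it, i.e. when all 32-bit words are identical.
  // For example (illustrative), an f32[4] literal of all 1.0f values passes,
  // while {1.0f, 2.0f, 1.0f, 2.0f} does not.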
3286 if (num_bytes >= 4 && num_bytes % 4 == 0 &&
3287 memcmp(init_value.data(), init_value.data() + 4, init_value.size() - 4) ==
3288 0) {
3289 uint32 word;
3290 memcpy(&word, init_value.data(), sizeof(word));
3291 return absl::make_unique<Memset32BitValueThunk>(Thunk::ThunkInfo(), word,
3292 dest);
3293 }
3294
3295 return nullptr;
3296 }
3297
3298 StatusOr<std::unique_ptr<Thunk>>
3299 IrEmitterUnnested::TryBuildConstantInitializerThunk(mlir::Value init_value,
3300 mlir::Value dest) {
3301 mlir::DenseElementsAttr const_init;
3302 if (auto get_global_memref =
3303 mlir::dyn_cast_or_null<mlir::memref::GetGlobalOp>(
3304 init_value.getDefiningOp())) {
3305 auto global_memref =
3306 mlir::SymbolTable::lookupNearestSymbolFrom<mlir::memref::GlobalOp>(
3307 get_global_memref, get_global_memref.name());
3308 if (global_memref.constant() && global_memref.initial_value()) {
3309 // If the initial value happens to be a constant, generate a specialized
3310 // thunk.
3311 const_init = global_memref.initial_value()
3312 .getValue()
3313 .cast<mlir::DenseElementsAttr>();
3314 }
3315 } else if (auto constant = mlir::dyn_cast_or_null<mlir::mhlo::ConstOp>(
3316 init_value.getDefiningOp())) {
3317 const_init = constant.value().dyn_cast<mlir::DenseElementsAttr>();
3318 }
3319
3320 if (const_init) {
3321 std::vector<uint8> literal_bytes;
3322 TF_RETURN_IF_ERROR(
3323 CopyDenseElementsDataToXlaFormat(const_init, &literal_bytes));
3324
3325 TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSlice(dest));
3326
3327 const Shape dest_shape = GetShape(dest);
3328 auto thunk =
3329 BuildConstantInitializerThunk(literal_bytes, dest_slice, dest_shape);
3330 if (thunk) {
3331 return {std::move(thunk)};
3332 }
3333 }
3334 return std::unique_ptr<Thunk>();
3335 }
3336
3337 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildInitializerThunk(
3338 mlir::Operation* op, mlir::Value init_value, mlir::Value dest) {
3339   // The initial value must be a scalar memref.
3340 auto init_type = init_value.getType().dyn_cast<mlir::MemRefType>();
3341 TF_RET_CHECK(init_type.getRank() == 0);
3342
3343 TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> constant_init_thunk,
3344 TryBuildConstantInitializerThunk(init_value, dest));
3345 if (constant_init_thunk) {
3346 return {std::move(constant_init_thunk)};
3347 }
3348
3349 // Otherwise fall back to our slow initializer code. The thunk in this case
3350 // will just need the IR arrays for the initial value and the destination.
3351 const Shape dest_shape = GetShape(dest);
3352 TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
3353 CalculateLaunchDimensions(
3354 dest_shape, ir_emitter_context_->gpu_device_info()));
3355 std::vector<llvm_ir::IrArray> ir_arrays;
3356 TF_ASSIGN_OR_RETURN(
3357 std::unique_ptr<KernelThunk> kernel_thunk,
3358 BuildKernelThunk(op, {init_value, dest}, Thunk::ThunkInfo(), &ir_arrays,
3359 launch_dimensions));
3360
3361 const llvm_ir::IrArray init_array = ir_arrays[0];
3362 const llvm_ir::IrArray dest_array = ir_arrays[1];
3363
3364 std::string name = mlir::GetNameFromLoc(op->getLoc());
3365 TF_RETURN_IF_ERROR(ParallelLoopEmitter(
3366 [=](const IrArray::Index& index) {
3367 return init_array.EmitReadArrayElement(index, &b_);
3368 },
3369 dest_array, launch_dimensions, &b_)
3370 .EmitLoop(mlir::GetNameFromLoc(op->getLoc())));
3371
3372 // Convert unique_ptr<KernelThunk> to StatusOr<unique_ptr<Thunk>>.
3373 return {std::move(kernel_thunk)};
3374 }
3375
3376 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildFusedInitializerThunk(
3377 mlir::lmhlo::FusionOp fusion, int output_index) {
3378 auto reduce = mlir::dyn_cast_or_null<mlir::mhlo::ReduceOp>(
3379 fusion.getFusionResults()[output_index].getDefiningOp());
3380
3381 TF_RET_CHECK(reduce);
3382 TF_RET_CHECK(reduce.getNumResults() == 1);
3383
3384 mlir::Value init_value = reduce.init_values()[0];
3385 mlir::Value dest = fusion.getOutputBuffers()[output_index];
3386 TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> constant_init_thunk,
3387 TryBuildConstantInitializerThunk(init_value, dest));
3388 if (constant_init_thunk) {
3389 return {std::move(constant_init_thunk)};
3390 }
3391
3392 auto input_buffers = fusion.getInputBuffers();
3393
3394 const Shape dest_shape = GetShape(dest);
3395 TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
3396 CalculateLaunchDimensions(
3397 dest_shape, ir_emitter_context_->gpu_device_info()));
3398 std::vector<llvm_ir::IrArray> ir_arrays;
3399 TF_ASSIGN_OR_RETURN(std::unique_ptr<KernelThunk> kernel_thunk,
3400 BuildKernelThunk(fusion, Thunk::ThunkInfo(), &ir_arrays,
3401 launch_dimensions));
3402
3403 const llvm_ir::IrArray dest_array =
3404 ir_arrays[input_buffers.size() + output_index];
3405
3406 const HloComputation* fused_computation =
3407 *GetOrCreateSubComputationFromRegion(&fusion.region(),
3408 /*is_fusion=*/true);
3409
3410 // If init_value was fused into this reduce we have to generate it first.
3411 GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
3412 ir_emitter_context_->llvm_module(),
3413 &b_, GetNestedComputer());
3414
3415 FusedIrEmitter fused_emitter(&elemental_emitter);
3416 for (int i = 0; i < fused_computation->num_parameters(); i++) {
3417 fused_emitter.BindGenerator(
3418 fused_computation->parameter_instruction(i),
3419 [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
3420 return ir_arrays[i].EmitReadArrayElement(index, &b_);
3421 });
3422 }
3423 HloInstruction* instr = fused_computation->root_instruction();
3424 if (instr->opcode() != HloOpcode::kTuple) {
3425 CHECK_EQ(0, output_index);
3426 } else {
3427 instr = instr->mutable_operand(output_index);
3428 }
3429 TF_RET_CHECK(instr->shape().IsArray());
3430 TF_ASSIGN_OR_RETURN(auto generator,
3431 fused_emitter.GetGenerator(instr->operand(1)));
3432 TF_RETURN_IF_ERROR(
3433 ParallelLoopEmitter(generator, dest_array, launch_dimensions, &b_)
3434 .EmitLoop(mlir::GetNameFromLoc(fusion.getLoc())));
3435 return {std::move(kernel_thunk)};
3436 }
3437
3438 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildWhileThunk(
3439 mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info) {
3440 // Generate thunk sequence for while 'condition'.
3441 mlir::Region* condition = &while_op.cond();
3442 TF_ASSIGN_OR_RETURN(
3443 auto ir_emitter_condition,
3444 IrEmitterUnnested::Create(hlo_module_config_, ir_emitter_context_));
3445
3446 TF_RETURN_IF_ERROR(ir_emitter_condition->EmitLmhloRegion(condition));
3447
3448 // Generate thunk sequence for while 'body'.
3449 mlir::Region* body = &while_op.body();
3450 TF_ASSIGN_OR_RETURN(
3451 auto ir_emitter_body,
3452 IrEmitterUnnested::Create(hlo_module_config_, ir_emitter_context_));
3453
3454 TF_RETURN_IF_ERROR(ir_emitter_body->EmitLmhloRegion(body));
3455
3456   // Extract the condition value from the last op (excluding the terminator op)
3457 // in the condition region.
3458 auto cond_result = GetHloOutputs(while_op);
3459 TF_RET_CHECK(cond_result.size() == 1);
3460 TF_ASSIGN_OR_RETURN(auto cond_result_slice,
3461 GetAllocationSlice(cond_result[0]));
3462
3463 return std::unique_ptr<Thunk>(
3464 new WhileThunk(thunk_info, cond_result_slice,
3465 ir_emitter_condition->ConsumeThunkSequence(),
3466 ir_emitter_body->ConsumeThunkSequence()));
3467 }
3468
3469 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildForThunk(
3470 mlir::lmhlo::WhileOp while_op, const Thunk::ThunkInfo& thunk_info,
3471 const int64_t loop_limit) {
3472   // Generate thunk sequence for while 'body' (will be used as the For loop body).
3473 TF_ASSIGN_OR_RETURN(
3474 auto ir_emitter_body,
3475 IrEmitterUnnested::Create(hlo_module_config_, ir_emitter_context_));
3476 TF_RETURN_IF_ERROR(ir_emitter_body->EmitLmhloRegion(&while_op.body()));
3477
3478 return std::unique_ptr<Thunk>(new ForThunk(
3479 thunk_info, loop_limit, ir_emitter_body->ConsumeThunkSequence()));
3480 }
3481
3482 Status IrEmitterUnnested::EmitTargetElementLoop(
3483 const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter) {
3484 return InternalError("This should be unreachable");
3485 }
3486
3487 // Gets the output offset as calculated from thread_id.x (to be applied to the
3488 // offset calculated from block_id and thread_id.y).
3489 static llvm::Value* GetStartOffsetX(const KernelMappingScheme& mapping_scheme,
3490 llvm::Value* thread_id_x,
3491 llvm::Type* index_ty,
3492 llvm::IRBuilder<>* b) {
3493 auto constant = [&](int64_t val) {
3494 return llvm::ConstantInt::get(index_ty, val);
3495 };
3496 if (mapping_scheme.GetIndexingOrder() == kStridedIndexingX) {
3497 return thread_id_x;
3498 } else if (mapping_scheme.GetIndexingOrder() == kStridedLinearIndexingX) {
3499 return b->CreateMul(thread_id_x, constant(mapping_scheme.GetVectorSize()));
3500 }
3501 CHECK_EQ(mapping_scheme.GetIndexingOrder(), kLinearIndexingX);
3502 int64_t x_num_steps =
3503 mapping_scheme.GetTileSizeX() / mapping_scheme.GetNumThreadsX();
3504 return b->CreateMul(thread_id_x, constant(x_num_steps));
3505 }
3506
3507 // Calls `emit_elem_function()` `x_num_steps` times. If `vector_size` == 1,
3508 // consecutive element indices passed to `emit_elem_function()` are separated
3509 // by `step_x`. If `vector_size` > 1, then `x_num_steps` must be a multiple of
3510 // `vector_size`; in that case a different indexing order is used that is
3511 // vectorizable by LLVM: the calls to `emit_elem_function` are emitted in
3512 // groups of `vector_size`, the groups are separated by `step_x * vector_size`
3513 // elements, and inside a group the elements are consecutive. If
3514 // `check_x_tile_bounds` is true, the element index is checked against
3515 // `tile_width` before calling `emit_elem_function`.
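//
// As a concrete (hypothetical) illustration: with x_num_steps = 8,
// vector_size = 4 and step_x = 32, the x offsets relative to start_offset_x
// are
//
//   group j = 0:   0,   1,   2,   3    (consecutive, vectorizable by LLVM)
//   group j = 1: 128, 129, 130, 131    (offset = j * step_x * vector_size + i)
//
// i.e. two groups of four consecutive elements, 128 elements apart.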
3517 static void UnrollInnerTileLoop(
3518 bool check_x_tile_bounds, int64_t x_num_steps, int64_t step_x,
3519 int64_t vector_size, const string& loop_name, KernelSupportLibrary* ksl,
3520 llvm::Value* start_offset_x, llvm::Value* y_loc, llvm::Value* tile_width,
3521 const IrArray::Index& source_idx, llvm::IRBuilder<>* b,
3522 const IrEmitterUnnested::EmitElementFunction* emit_elem_function) {
3523 llvm::Type* index_ty = tile_width->getType();
3524 auto constant = [&](int64_t val) {
3525 return llvm::ConstantInt::get(index_ty, val);
3526 };
3527 IrArray::Index source_idx_x_base = source_idx.AddOffsetToDim(y_loc, kDimY, b);
3528 for (int64_t j = 0; j < x_num_steps / vector_size; j++) {
3529 for (int64_t i = 0; i < vector_size; i++) {
3530 int64_t linear_index = j * vector_size + i;
3531 llvm::Value* x_loc = b->CreateAdd(constant(j * step_x * vector_size + i),
3532 start_offset_x, "x_loc");
3533 IrArray::Index source_idx_x = source_idx_x_base.AddOffsetToDim(
3534 constant(j * step_x * vector_size + i), kDimX, b);
3535 auto emit_element = [&] {
3536 return (*emit_elem_function)(source_idx_x, y_loc, x_loc, linear_index);
3537 };
3538 if (check_x_tile_bounds) {
3539 ksl->If(loop_name + "_x_in_tile", b->CreateICmpULT(x_loc, tile_width),
3540 emit_element);
3541 } else {
3542 emit_element();
3543 }
3544 }
3545 }
3546 }
3547
3548 void IrEmitterUnnested::EmitTile(
3549 const KernelMappingScheme& mapping_scheme,
3550 const IrArray::Index& tile_origin_index, const string& loop_name,
3551 KernelSupportLibrary* ksl, const ThreadIdInfo& thread_id_info,
3552 llvm::Value* tile_height, llvm::Value* tile_width,
3553 const IrEmitterUnnested::EmitElementFunction& emit_elem_function) {
3554 llvm::Type* index_ty = tile_width->getType();
3555 auto constant = [&](int64_t val) {
3556 return llvm::ConstantInt::get(index_ty, val);
3557 };
3558 int64_t num_threads_x = mapping_scheme.GetNumThreadsX();
3559 llvm::Value* num_threads_y = constant(mapping_scheme.GetNumThreadsY());
3560 int64_t tile_size_x = mapping_scheme.GetTileSizeX();
3561
3562 int64_t x_num_steps = tile_size_x / num_threads_x;
3563 llvm::Value* start_offset_x = GetStartOffsetX(
3564 mapping_scheme, thread_id_info.thread_id_x, index_ty, &b_);
3565
3566   // With the dilated (strided) mapping scheme, each thread steps with a stride
3567   // equal to the number of threads.
3568   // Otherwise the stride is one, but each thread's starting offset is
3569   // multiplied by the number of steps it can take.
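  // For example (hypothetical sizes): with tile_size_x = 128 and
  // num_threads_x = 32, x_num_steps = 4. Under linear indexing thread t covers
  // tile elements [4*t, 4*t + 3]; under strided indexing it covers elements
  // t, t + 32, t + 64 and t + 96.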
3570 int64_t step_x =
3571 mapping_scheme.GetIndexingOrder() == kLinearIndexingX ? 1 : num_threads_x;
3572 int64_t vector_size = mapping_scheme.GetVectorSize();
3573
3574 IrArray::Index source_idx =
3575 tile_origin_index.AddOffsetToDim(start_offset_x, kDimX, &b_);
3576
3577 // True iff all threads always execute all instructions in the tiling
3578 // dimension X.
3579 bool x_tile_fits =
3580 mapping_scheme.GetDimsInElems()[kDimX] % tile_size_x == 0 &&
3581 mapping_scheme.GetRowContiguous();
3582
3583 ksl->For(
3584 loop_name + "_y_in_tile",
3585 /*start=*/thread_id_info.thread_id_y,
3586 /*end=*/
3587 tile_height,
3588 /*step=*/num_threads_y, [&](llvm::Value* y_loc) {
3589 auto unroll_inner_tile_loop = [&](bool check_x_tile_bounds) {
3590 return UnrollInnerTileLoop(check_x_tile_bounds, x_num_steps, step_x,
3591 vector_size, loop_name, ksl,
3592 start_offset_x, y_loc, tile_width,
3593 source_idx, &b_, &emit_elem_function);
3594 };
3595
3596 // Only take this path when we unroll in a way vectorizable by
3597 // LLVM. Special case when the tile doesn't fit completely for even
3598 // row size. For odd row size every other row isn't aligned to the
3599 // vectorized size, so it can't be vectorized by LLVM.
3600 if (!x_tile_fits &&
3601 mapping_scheme.GetIndexingOrder() == kStridedLinearIndexingX) {
3602 ksl->If(
3603 loop_name + "_is_full_tile",
3604 // For the last block, tile_width will be the number of
3605 // elements left.
3606 b_.CreateICmpEQ(constant(mapping_scheme.GetTileSizeX()),
3607 tile_width),
3608 [&] { unroll_inner_tile_loop(/*check_x_tile_bounds=*/false); },
3609 [&] { unroll_inner_tile_loop(/*check_x_tile_bounds=*/true); });
3610 } else {
3611 unroll_inner_tile_loop(/*check_x_tile_bounds=*/!x_tile_fits);
3612 }
3613 });
3614 }
3615
3616 static IrArray::Index GetUnnormalizedIndex(
3617 const IrArray::Index& normalized_shape_index,
3618 const Shape& unnormalized_shape, llvm::IRBuilder<>* b_,
3619 const KernelMappingScheme& kernel_mapping_scheme) {
3620 DCHECK_EQ(normalized_shape_index.size(), 3);
3621   // If the normalization only adds new dimensions of size 1,
3622 // generate simpler indexing. LLVM doesn't always simplify the more
3623 // complicated indexing and this prevents it from vectorizing some
3624 // cases. We do this only for major_to_minor memory layout.
3625 if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() &&
3626 unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[1] &&
3627 unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[2] &&
3628 unnormalized_shape.layout().minor_to_major(1) == 0) {
3629 CHECK_EQ(normalized_shape_index.dims()[0], 1);
3630 auto multidim = normalized_shape_index.multidim();
3631 return IrArray::Index({multidim[1], multidim[2]}, unnormalized_shape,
3632 normalized_shape_index.GetType());
3633 }
3634 if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() &&
3635 unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[2] &&
3636 unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[1] &&
3637 unnormalized_shape.layout().minor_to_major(1) == 1) {
3638 CHECK_EQ(normalized_shape_index.dims()[0], 1);
3639 auto multidim = normalized_shape_index.multidim();
3640 return IrArray::Index({multidim[2], multidim[1]}, unnormalized_shape,
3641 normalized_shape_index.GetType());
3642 }
3643 llvm::Value* linear = normalized_shape_index.Linearize(
3644 kernel_mapping_scheme.GetDimsInElems(), b_);
3645 return IrArray::Index(linear, unnormalized_shape, b_);
3646 }
3647
3648 // Emits code to process a tensor element in a tile for the given kLoop fusion
3649 // HLO containing parameters that are 0-2-1 transpose of its outputs.
3650 //
3651 // index: The index for the first output element in the normalized tensor, that
3652 // is the resulting tensor after collapsing contiguous dimensions that play
3653 // the same role in the transpose.
3654 // kernel_info: Other information to support the kernel code generation.
3655 void IrEmitterUnnested::EmitTileElementForFusion(
3656 mlir::lmhlo::FusionOp fusion,
3657 absl::Span<const llvm_ir::IrArray> operand_arrays,
3658 absl::Span<const llvm_ir::IrArray> output_arrays,
3659 const llvm_ir::IrArray::Index& index,
3660 const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
3661 llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) {
3662 const HloComputation* fused_computation =
3663 *GetOrCreateSubComputationFromRegion(&fusion.region(),
3664 /*is_fusion=*/true);
3665 GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
3666 GetNestedComputer());
3667 FusedIrEmitter fused_emitter(&elem_emitter);
3668 for (int i = 0; i < operand_arrays.size(); i++) {
3669 llvm_ir::ElementGenerator gen;
3670 if (llvm::Value* param_tile_buffer = param_shmem_buffers[i]) {
3671 gen = [this, param_tile_buffer, x_loc,
3672 y_loc](llvm_ir::IrArray::Index index) {
3673 // TODO(jlebar): Add AA metadata to this load. Tile buffers are
3674 // global variables, so LLVM's points-to analysis doesn't help us
3675 // much. And we want the AA info to be present before address
3676 // spaces are inferred (which is pretty late in the pipeline), so
3677 // even if we had address-space-based AA in LLVM, it wouldn't help
3678 // us much here.
3679 return b_.CreateLoad(
3680 b_.CreateGEP(param_tile_buffer,
3681 {index.GetConstantWithIndexType(0), x_loc, y_loc}),
3682 "tiled_buffer");
3683 };
3684 } else {
3685 auto array = operand_arrays[i];
3686 auto name = fused_computation->parameter_instruction(i)->name();
3687 gen = [this, array, name](const llvm_ir::IrArray::Index& index) {
3688 return array.EmitReadArrayElement(index, &b_, name);
3689 };
3690 }
3691 fused_emitter.BindGenerator(fused_computation->parameter_instruction(i),
3692 std::move(gen));
3693 }
3694 IrArray::Index untiled_index = GetUnnormalizedIndex(
3695 index, output_arrays[0].GetShape(), &b_, mapping_scheme);
3696 llvm_ir::ElementGenerator output_generator =
3697 *fused_emitter.GetGenerator(fused_computation->root_instruction());
3698 llvm::Value* output_value = output_generator(untiled_index).ValueOrDie();
3699 if (output_arrays.size() > 1) {
3700 DCHECK(output_value->getType()->isStructTy());
3701 DCHECK_EQ(output_value->getType()->getStructNumElements(),
3702 output_arrays.size());
3703 for (int64_t i = 0; i < output_arrays.size(); ++i) {
3704 output_arrays[i].EmitWriteArrayElement(
3705 untiled_index, ExtractValue(output_value, i), &b_);
3706 }
3707 } else {
3708 output_arrays[0].EmitWriteArrayElement(untiled_index, output_value, &b_);
3709 }
3710 }
3711
3712 static mlir::Operation* GetReduceFromUnnestedMlir(mlir::Operation* unnested_hlo,
3713 int index) {
3714 auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(unnested_hlo);
3715 auto results = fusion.getFusionResults();
3716 CHECK(index < results.size())
3717 << MlirToString(unnested_hlo) << " vs " << index;
3718 return results[index].getDefiningOp();
3719 }
3720
3721 void IrEmitterUnnested::EmitPrologueForReduction(
3722 mlir::Operation* unnested_hlo, absl::Span<const int> instr_index_group,
3723 HloComputation* fused_computation, FusedIrEmitter* fused_emitter,
3724 absl::Span<const llvm_ir::IrArray> result_ir_arrays,
3725 ReductionCodegenState* reduction_info,
3726 const FusionLayoutAnalysis& layout_analysis) {
3727 VLOG(10) << "Emit prologue for reduction: " << MlirToString(unnested_hlo);
3728 mlir::Operation* first_reduce = nullptr;
3729 for (int index : instr_index_group) {
3730 mlir::Operation* reduce_inst =
3731 GetReduceFromUnnestedMlir(unnested_hlo, index);
3732
3733 if (!IsReductionFromOrToContiguousDimensions(reduce_inst,
3734 layout_analysis)) {
3735 continue;
3736 }
3737
3738 auto results = GetHloOutputs(reduce_inst);
3739 CHECK_EQ(1, results.size());
3740 Shape reduce_inst_shape = layout_analysis.GetShape(results[0]);
3741
3742 VLOG(10) << "Emit prologue for reduction: " << MlirToString(reduce_inst);
3743 if (first_reduce == nullptr) {
3744 first_reduce = reduce_inst;
3745 } else {
3746 CHECK(absl::c_equal(
3747 first_reduce->getAttrOfType<mlir::DenseIntElementsAttr>("dimensions"),
3748 reduce_inst->getAttrOfType<mlir::DenseIntElementsAttr>(
3749 "dimensions")));
3750 }
3751
3752 AddressVector* reduction_input_addresses =
3753 reduction_info->GetMutableReductionInputAddresses();
3754 llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType(
3755 reduce_inst_shape.element_type(), ir_emitter_context_->llvm_module());
3756 llvm::AllocaInst* reduction_input_address =
3757 llvm_ir::EmitAllocaAtFunctionEntry(element_type,
3758 "reduction_input_address", &b_);
3759 reduction_input_addresses->push_back(reduction_input_address);
3760
3761 int num_partial_results = reduction_info->GetNumPartialResults();
3762 AddressVector* partial_result_addresses =
3763 reduction_info->GetMutablePartialResultAddresses();
3764 llvm::AllocaInst* partial_result_address =
3765 llvm_ir::EmitAllocaAtFunctionEntryWithCount(
3766 element_type, /*element_count=*/b_.getInt32(num_partial_results),
3767 ("partial_reduction_result." + llvm::Twine(index)).str(), &b_);
3768 partial_result_addresses->push_back(partial_result_address);
3769
3770 // Initialize the partial result with the initial value of the reduction.
3771 llvm::Value* init_ir_value;
3772 const HloInstruction* reduce_hlo = fused_computation->root_instruction();
3773 if (reduce_hlo->opcode() == HloOpcode::kTuple) {
3774 reduce_hlo = reduce_hlo->operand(index);
3775 }
3776 const HloInstruction* init_value = reduce_hlo->operand(1);
3777 init_ir_value = (*fused_emitter->GetGenerator(init_value))(
3778 IrArray::Index(b_.getInt32Ty()))
3779 .ValueOrDie();
3780
3781 for (int i = 0; i < num_partial_results; ++i) {
3782 b_.CreateStore(init_ir_value, b_.CreateInBoundsGEP(partial_result_address,
3783 {b_.getInt32(i)}));
3784 }
3785 reduction_info->GetMutableInitialValues()->push_back(init_ir_value);
3786
3787 auto& mapping_scheme = reduction_info->GetKernelMappingScheme();
3788 int64_t num_threads_x = mapping_scheme.GetNumThreadsX();
3789 llvm::Type* primitive_type = llvm_ir::PrimitiveTypeToIrType(
3790 reduce_inst_shape.element_type(), module_);
3791 llvm::Type* buffer_type = [&] {
3792 if (reduction_info->IsRowReduction()) {
3793 // Allocate __shared__ cache[num_partial_results][kWarpSize].
3794 // TODO(cheshire): Do we need the same trick as below to avoid bank
3795 // conflicts?
3796 return llvm::ArrayType::get(
3797 llvm::ArrayType::get(primitive_type, kWarpSize),
3798 num_partial_results);
3799 } else {
3800 // Allocate __shared__
3801 // cache[num_partial_results][num_threads][num_threads + 1], where
3802 // num_threads == num_threads_x == num_threads_y. The "+1" is used to
3803 // avoid bank conflicts.
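        // For example (illustrative): with num_threads_x == 32 and f32
        // elements, each partial result gets a 32 x 33 float tile (~4.1 KiB);
        // the padded row stride of 33 makes same-column accesses from
        // different rows land in different shared-memory banks.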
3804 CHECK_EQ(num_threads_x, mapping_scheme.GetNumThreadsY());
3805 return llvm::ArrayType::get(
3806 llvm::ArrayType::get(
3807 llvm::ArrayType::get(primitive_type, num_threads_x + 1),
3808 num_threads_x),
3809 num_partial_results);
3810 }
3811 }();
3812 llvm::GlobalVariable* shared_cache_per_reduce =
3813 llvm_ir::AllocateSharedMemoryTile(b_.GetInsertBlock()->getModule(),
3814 buffer_type,
3815 absl::StrCat("shared_cache_", index));
3816 reduction_info->GetMutableSharedCache()->push_back(shared_cache_per_reduce);
3817 }
3818 CHECK(first_reduce);
3819 }
3820
3821 void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForAllReduces(
3822 absl::Span<HloComputation* const> reducers,
3823 absl::Span<llvm::AllocaInst* const> partial_result_addresses,
3824 int threads_per_block) {
3825 CHECK_EQ(reducers.size(), partial_result_addresses.size());
3826 for (int i = 0; i != reducers.size(); i++) {
3827 EmitFullWarpShuffleDownLoopForReduce(
3828 reducers[i], partial_result_addresses[i]->getType()->getElementType(),
3829 partial_result_addresses[i], threads_per_block);
3830 }
3831 }
3832
3833 void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForReduce(
3834 HloComputation* reducer, llvm::Type* element_type,
3835 llvm::Value* partial_result_address, int threads_per_block) {
3836 // This only works when the block size is a multiple of 32 threads.
3837 CHECK_EQ(threads_per_block % 32, 0);
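  // The loop below emits the standard warp-level shuffle-down tree reduction.
  // A rough CUDA-level equivalent of the generated code (assuming a scalar
  // element type and a binary `reduce` matching the HLO reducer) is:
  //
  //   for (int distance = 16; distance >= 1; distance /= 2) {
  //     partial = reduce(partial,
  //                      __shfl_down_sync(0xffffffff, partial, distance));
  //   }
  //
  // after which lane 0 of every warp holds that warp's combined result.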
3838 for (int distance = 16; distance >= 1; distance /= 2) {
3839 int bit_width = llvm_ir::GetSizeInBits(element_type);
3840 llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry(
3841 element_type, "result_from_other_lane", &b_);
3842 // Bitcast cannot be applied to aggregate types (even packed ones), so
3843 // we bitcast addresses of load/store to intN* of the same bit-width.
3844 llvm::Type* shuffled_value_type =
3845 element_type->isStructTy() ? b_.getIntNTy(bit_width) : element_type;
3846 auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) {
3847 return b_.CreatePointerBitCastOrAddrSpaceCast(
3848 ptr, shuffled_value_type->getPointerTo());
3849 };
3850 llvm::Value* partial_result =
3851 Load(convert_pointer_for_shuffle(partial_result_address),
3852 "partial_reduction_result");
3853 Store(EmitFullWarpShuffleDown(partial_result, b_.getInt32(distance), &b_),
3854 convert_pointer_for_shuffle(result_from_other_lane));
3855 TF_CHECK_OK(EmitCallToNestedComputation(
3856 *reducer, {partial_result_address, result_from_other_lane},
3857 partial_result_address));
3858 }
3859 }
3860
3861 // Given the IrArray index of a reduction input, returns the linear address of
3862 // the reduction output as if the reduction were going to keep the input shape
3863 // with the dimensions being reduced moved.
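// Concretely (restating the code below): for a row reduction this is just the
// y coordinate of the normalized index; for a column reduction it is
// index[kDimZ] * dims_in_elem[kDimX] + index[kDimX], i.e. the transposed
// address within the kept dimensions.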
3864 static llvm::Value* GetUntransposedOutputLinearAddress(
3865 llvm::IRBuilder<>* b, const llvm_ir::IrArray::Index& index,
3866 const ReductionCodegenState& reduction_info) {
3867 const KernelMappingScheme& kernel_mapping_scheme =
3868 reduction_info.GetKernelMappingScheme();
3869 if (reduction_info.IsRowReduction()) {
3870 // For row-reduction, y-coordinate determines which row we write into.
3871 return index[kDimY];
3872 }
3873 // For column reduction, we get the transposed address.
3874 absl::Span<const int64> dims_in_elem = kernel_mapping_scheme.GetDimsInElems();
3875 llvm::Value* x_dim_size = index.GetConstantWithIndexType(dims_in_elem[kDimX]);
3876 llvm::Value* x_block_offset = b->CreateMul(index[kDimZ], x_dim_size);
3877 return b->CreateAdd(x_block_offset, index[kDimX]);
3878 }
3879
3880 void IrEmitterUnnested::EmitEpilogueForReduction(
3881 llvm::Type* index_ty, mlir::Operation* unnested_hlo,
3882 absl::Span<const int> instr_index_group,
3883 absl::Span<const llvm_ir::IrArray> result_ir_arrays,
3884 absl::Span<HloComputation* const> reducers,
3885 const ReductionCodegenState& reduction_info,
3886 const TilingKernelInfo& tiling_kernel_info,
3887 const FusionLayoutAnalysis& layout_analysis) {
3888 const KernelMappingScheme& mapping_scheme =
3889 reduction_info.GetKernelMappingScheme();
3890 auto constant = [&](uint64 c) -> llvm::Constant* {
3891 return llvm::ConstantInt::get(index_ty, c);
3892 };
3893
3894 IrEmitterUnnested::ThreadIdInfo thread_id_info =
3895 EmitThreadIdInfo(mapping_scheme.GetThreadsPerBlock(), index_ty,
3896 mapping_scheme.GetNumThreadsX());
3897
3898 IrArray::Index start_offset = [&] {
3899 llvm::Value* x_loc = thread_id_info.thread_id_x;
3900 llvm::Value* y_loc = thread_id_info.thread_id_y;
3901 if (!reduction_info.IsRowReduction()) {
3902 std::swap(x_loc, y_loc);
3903 }
3904 llvm::Value* start_offset_x =
3905 GetStartOffsetX(mapping_scheme, x_loc, index_ty, &b_);
3906 return tiling_kernel_info.tile_origin.AddOffsetToDim(y_loc, kDimY, &b_)
3907 .AddOffsetToDim(start_offset_x, kDimX, &b_);
3908 }();
3909
3910 absl::Span<llvm::AllocaInst* const> partial_result_addresses =
3911 reduction_info.GetPartialResultAddresses();
3912
3913 int reduction_idx = -1;
3914
3915   // `instruction_idx` indexes over all instructions in a group; some of
3916   // them might not be unnested reductions.
3917 for (int instruction_idx : instr_index_group) {
3918 mlir::Operation* reduce_hlo =
3919 GetReduceFromUnnestedMlir(unnested_hlo, instruction_idx);
3920 llvm_ir::IrArray output_array = result_ir_arrays[instruction_idx];
3921 if (!IsReductionFromOrToContiguousDimensions(reduce_hlo, layout_analysis)) {
3922 continue;
3923 }
3924 reduction_idx++;
3925
3926 Shape operand_shape = layout_analysis.GetShape(reduce_hlo->getOperand(0));
3927 Shape reduction_kept_element_shape = ShapeUtil::FilterDimensions(
3928 [&](int64_t dim) {
3929 return !absl::c_linear_search(
3930 reduce_hlo->getAttrOfType<mlir::DenseIntElementsAttr>(
3931 "dimensions"),
3932 dim);
3933 },
3934 operand_shape);
3935
3936 for (int partial_result_idx = 0;
3937 partial_result_idx < reduction_info.GetNumPartialResults();
3938 ++partial_result_idx) {
3939 llvm::Value* untransposed_output_linear_address =
3940 GetUntransposedOutputLinearAddress(
3941 &b_,
3942 start_offset.AddOffsetToDim(constant(partial_result_idx), kDimX,
3943 &b_),
3944 reduction_info);
3945
3946 // A reduction is allowed to transpose its output. For example, suppose
3947       // we are reducing the second dimension of f32[10,20,30]{2,1,0}. We are
3948 // allowed to produce as output either f32[10,30]{1,0} (no transpose) or
3949 // f32[10,30]{0,1} (transposing the two output dims).
3950 //
3951 // At this point in the function we have a "partial sum" of input elements
3952 // (stored in partial_result_addresses), and we need to accumulate it into
3953 // the correct output element.
3954 IrArray::Index element_index(
3955 /*linear=*/untransposed_output_linear_address,
3956 reduction_kept_element_shape, &b_);
3957 IrArray::Index output_index(element_index.multidim(),
3958 output_array.GetShape(),
3959 element_index.GetType());
3960 llvm::Value* output_address = output_array.EmitArrayElementAddress(
3961 output_index, &b_, "output_element_address");
3962 llvm::Value* current_output = b_.CreateInBoundsGEP(
3963 partial_result_addresses[reduction_idx],
3964 {constant(partial_result_idx)}, "current_output");
3965
3966 llvm::Type* element_type =
3967 partial_result_addresses[reduction_idx]->getType()->getElementType();
3968 if (reduction_info.IsRowReduction()) {
3969 EmitEpilogueForRowReduction(reducers[reduction_idx], thread_id_info,
3970 reduction_info, element_type, index_ty,
3971 current_output, output_address,
3972 reduction_idx, partial_result_idx);
3973 } else {
3974 EmitEpilogueForColumnReduction(
3975 reducers[reduction_idx], thread_id_info, reduction_info,
3976 element_type, index_ty, current_output, output_address,
3977 reduction_idx, partial_result_idx, tiling_kernel_info);
3978 }
3979 }
3980 }
3981 }
3982
3983 llvm::Value* IrEmitterUnnested::EmitBlockId() {
3984 return gpu::EmitCallToTargetIntrinsic(gpu::TargetIntrinsicID::kBlockIdx, {},
3985 {}, &b_);
3986 }
3987
3988 void IrEmitterUnnested::EmitPrintfWithThreadId(
3989 absl::string_view fmt, absl::Span<llvm::Value* const> arguments,
3990 absl::optional<int64> thread_id_filter,
3991 absl::optional<int64> block_id_filter) {
3992 llvm::Value* thread_id = EmitThreadId(1024, b_.getInt32Ty());
3993 llvm::Value* block_id = EmitBlockId();
3994 std::vector<llvm::Value*> updated_arguments = {thread_id, block_id};
3995 updated_arguments.insert(updated_arguments.end(), arguments.begin(),
3996 arguments.end());
3997 llvm::Value* constraint = b_.getTrue();
3998 if (thread_id_filter) {
3999 constraint = b_.CreateAnd(
4000 constraint, b_.CreateICmpEQ(thread_id, b_.getInt32(*thread_id_filter)));
4001 }
4002 if (block_id_filter) {
4003 constraint = b_.CreateAnd(
4004 constraint, b_.CreateICmpEQ(block_id, b_.getInt32(*block_id_filter)));
4005 }
4006 KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
4007 ksl.If(constraint, [&] {
4008 xla::gpu::EmitPrintf(absl::StrCat("[TID=%d,BID=%d] ", fmt, "\n"),
4009 updated_arguments, &b_);
4010 });
4011 }
4012 llvm::Value* IrEmitterUnnested::CastSharedToGlobal(llvm::Value* input,
4013 llvm::Twine name) {
4014 return b_.CreateAddrSpaceCast(
4015 input,
4016 llvm::PointerType::get(input->getType()->getPointerElementType(),
4017 /*AddressSpace=*/0),
4018 name);
4019 }
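// Illustration (editor's note): despite the name, the cast above produces a
// pointer in the generic address space (0). On NVPTX it corresponds to, e.g.,
//   %generic = addrspacecast float addrspace(3)* %shmem to float*
// so later loads and stores need no address-space-specific handling.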
4020
4021 void IrEmitterUnnested::EmitEpilogueForRowReduction(
4022 HloComputation* reducer,
4023 const IrEmitterUnnested::ThreadIdInfo& thread_id_info,
4024 const ReductionCodegenState& reduction_info, llvm::Type* element_type,
4025 llvm::Type* index_ty, llvm::Value* current_output,
4026 llvm::Value* output_address, int reduction_idx, int partial_result_idx) {
4027 auto constant = [&](uint64 c) -> llvm::Constant* {
4028 return llvm::ConstantInt::get(index_ty, c);
4029 };
4030 auto is_zero = [&](llvm::Value* value) {
4031 return b_.CreateICmpEQ(value, constant(0));
4032 };
4033 llvm::GlobalVariable* shared_cache =
4034 reduction_info.GetSharedCache()[reduction_idx];
4035 KernelSupportLibrary ksl(&b_);
4036 const KernelMappingScheme& mapping_scheme =
4037 reduction_info.GetKernelMappingScheme();
4038
4039 EmitFullWarpShuffleDownLoopForReduce(reducer, element_type, current_output,
4040 mapping_scheme.GetThreadsPerBlock());
4041 llvm::Value* warp_id =
4042 b_.CreateUDiv(thread_id_info.thread_id_x, constant(kWarpSize));
4043 ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] {
4044 llvm::Value* shmem_output_addr = CastSharedToGlobal(b_.CreateInBoundsGEP(
4045 shared_cache, {b_.getInt32(0), constant(partial_result_idx), warp_id}));
4046 b_.CreateStore(b_.CreateLoad(current_output), shmem_output_addr);
4047 });
4048
4049 // TODO(cheshire): Don't we want to sync it once for everything in the
4050 // output? Not once per each?
4051 EmitSyncThreads();
4052 ksl.If("inter_warp_reduce", is_zero(warp_id), [&] {
4053 llvm::Value* block_accum_addr = CastSharedToGlobal(b_.CreateInBoundsGEP(
4054 shared_cache, {b_.getInt32(0), constant(partial_result_idx),
4055 thread_id_info.lane_id}));
4056 llvm::Value* initial_value =
4057 reduction_info.GetInitialValues()[reduction_idx];
4058 llvm::Value* initial_value_addr =
4059 CastSharedToGlobal(llvm_ir::EmitAllocaAtFunctionEntry(
4060 element_type, "initial_value_addr", &b_));
4061 b_.CreateStore(initial_value, initial_value_addr);
4062
4063 llvm::Value* warp_exists =
4064 b_.CreateICmpULT(thread_id_info.thread_id_x,
4065 constant(mapping_scheme.GetNumThreadsX() / kWarpSize));
4066
4067 llvm::Value* selected_value =
4068 b_.CreateSelect(warp_exists, block_accum_addr, initial_value_addr);
4069
4070 EmitFullWarpShuffleDownLoopForReduce(reducer, element_type,
4071 /*block_accum_addr*/ selected_value,
4072 mapping_scheme.GetThreadsPerBlock());
4073 ksl.If("reduction_write_output", is_zero(thread_id_info.thread_id_x), [&] {
4074 if (reduction_info.IsRaceFree()) {
4075 VLOG(10) << "Using deterministic reductions: writing out "
4076 "the value directly";
4077 b_.CreateStore(b_.CreateLoad(block_accum_addr, "output"),
4078 output_address);
4079 } else {
4080 TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
4081 *reducer, output_address, block_accum_addr));
4082 }
4083 });
4084 });
4085 }
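// Illustration (editor's sketch of the control flow emitted above, not
// emitter code):
//   warp_reduce(partial);                       // shuffle-down within each warp
//   if (lane_id == 0) shmem[partial_result_idx][warp_id] = partial;
//   __syncthreads();
//   if (warp_id == 0) {
//     val = (thread_id_x < num_warps) ? shmem[partial_result_idx][lane_id]
//                                     : init_value;
//     warp_reduce(val);
//     if (thread_id_x == 0) store to output, or atomically accumulate when the
//                           reduction is not race-free.
//   }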
4086
4087 void IrEmitterUnnested::EmitEpilogueForColumnReduction(
4088 HloComputation* reducer,
4089 const IrEmitterUnnested::ThreadIdInfo& thread_id_info,
4090 const ReductionCodegenState& reduction_info, llvm::Type* element_type,
4091 llvm::Type* index_ty, llvm::Value* current_output,
4092 llvm::Value* output_address, int reduction_idx, int partial_result_idx,
4093 const TilingKernelInfo& tiling_kernel_info) {
4094 KernelSupportLibrary ksl(&b_);
4095 llvm::GlobalVariable* shared_cache =
4096 reduction_info.GetSharedCache()[reduction_idx];
4097 auto constant = [&](uint64 c) -> llvm::Constant* {
4098 return llvm::ConstantInt::get(index_ty, c);
4099 };
4100 auto is_zero = [&](llvm::Value* value) {
4101 return b_.CreateICmpEQ(value, constant(0));
4102 };
4103 const KernelMappingScheme& mapping_scheme =
4104 reduction_info.GetKernelMappingScheme();
4105 llvm::Value* shmem_output_addr = CastSharedToGlobal(
4106 b_.CreateInBoundsGEP(
4107 shared_cache,
4108 {b_.getInt32(0), constant(partial_result_idx),
4109 thread_id_info.thread_id_x, thread_id_info.thread_id_y}),
4110 "shmem_output_address");
4111 llvm::Value* current_output_value = b_.CreateLoad(current_output);
4112 b_.CreateStore(current_output_value, shmem_output_addr);
4113
4114 EmitSyncThreads();
4115
4116 // Get transposed element from shared memory.
4117 llvm::Value* shmem_transposed_addr = CastSharedToGlobal(b_.CreateInBoundsGEP(
4118 shared_cache,
4119 {b_.getInt32(0), constant(partial_result_idx), thread_id_info.thread_id_y,
4120 thread_id_info.thread_id_x},
4121 "shmem_transposed_addr"));
4122
4123 EmitFullWarpShuffleDownLoopForReduce(reducer, element_type,
4124 shmem_transposed_addr,
4125 mapping_scheme.GetThreadsPerBlock());
4126
4127 // Some warps in the block are completely outside of the bound of the
4128 // tensor, so they should not write any output at all.
4129 llvm::Value* has_output = b_.CreateAnd(
4130 b_.CreateICmpULT(
4131 GetStartOffsetX(mapping_scheme, thread_id_info.thread_id_y, index_ty,
4132 &b_),
4133 tiling_kernel_info.output_tile_bounds[kDimX]),
4134 b_.CreateICmpULT(thread_id_info.thread_id_x,
4135 tiling_kernel_info.output_tile_bounds[kDimY]));
4136
4137 ksl.If("reduction_write_output",
4138 b_.CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] {
4139 if (reduction_info.IsRaceFree()) {
4140 VLOG(10) << "Using deterministic reductions: writing out "
4141 "the value directly";
4142 b_.CreateStore(
4143 b_.CreateLoad(shmem_transposed_addr, "output_value"),
4144 output_address);
4145 } else {
4146 TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
4147 *reducer, output_address, shmem_transposed_addr));
4148 }
4149 });
4150 }
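// Illustration (editor's sketch of the control flow emitted above, not
// emitter code):
//   shmem[partial_result_idx][tid_x][tid_y] = partial;   // untransposed store
//   __syncthreads();
//   val = shmem[partial_result_idx][tid_y][tid_x];       // transposed read
//   warp_reduce(val);
//   if (lane_id == 0 && the thread maps to an in-bounds output element)
//     store to output, or atomically accumulate when not race-free.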
4151
4152 void IrEmitterUnnested::EmitTileElementForReduction(
4153 mlir::Operation* unnested_hlo, const Shape& reduction_operand_shape,
4154 absl::Span<const int> instr_index_group, HloComputation* fused_computation,
4155 FusedIrEmitter* fused_emitter,
4156 absl::Span<const llvm_ir::IrArray> result_ir_arrays,
4157 absl::Span<HloComputation* const> reducers,
4158 const llvm_ir::IrArray::Index& index,
4159 const ReductionCodegenState& reduction_info, int64_t x_iter_num,
4160 const FusionLayoutAnalysis& layout_analysis) {
4161 VLOG(10) << "Emit tile element for reduce " << MlirToString(unnested_hlo);
4162 int partial_result_index = reduction_info.IsRowReduction() ? 0 : x_iter_num;
4163
4164 InlinedVector<llvm_ir::ElementGenerator, 1> input_gens;
4165 std::vector<std::pair<llvm_ir::ElementGenerator, int>> extra_output_gens;
4166
4167   // Construct the ElementGenerator for each reduction and extra output in
4168   // the group of output instructions.
4169 for (int index : instr_index_group) {
4170 mlir::Operation* inst = GetReduceFromUnnestedMlir(unnested_hlo, index);
4171
4172 const HloInstruction* hlo = fused_computation->root_instruction();
4173 if (hlo->opcode() == HloOpcode::kTuple) {
4174 hlo = hlo->operand(index);
4175 }
4176 if (IsReductionFromOrToContiguousDimensions(inst, layout_analysis)) {
4177 input_gens.push_back(*fused_emitter->GetGenerator(hlo->operand(0)));
4178 } else {
4179 extra_output_gens.emplace_back(*fused_emitter->GetGenerator(hlo), index);
4180 }
4181 }
4182
4183 IrArray::Index input_index =
4184 GetUnnormalizedIndex(index, reduction_operand_shape, &b_,
4185 reduction_info.GetKernelMappingScheme());
4186 // Clear the linear index field of the IrArray::Index to enable the use of
4187 // GetElementPointer with array types. This enables the vectorization of
4188 // the computation for different partial results. Use this index if
4189 // 'num_partial_results > 1'.
4190 int num_partial_results = reduction_info.GetNumPartialResults();
4191 auto index_without_linear = IrArray::Index(
4192 input_index.multidim(), reduction_operand_shape, input_index.GetType());
4193
4194 // Emit code to generate the input and perform the reduction computation for
4195 // each reduction instruction.
4196 for (int i = 0; i < reducers.size(); i++) {
4197 llvm::AllocaInst* input_address =
4198 reduction_info.GetReductionInputAddresses()[i];
4199 llvm::AllocaInst* partial_reduction_result_address =
4200 reduction_info.GetPartialResultAddresses()[i];
4201 llvm::Value* const input_ir_value =
4202 input_gens[i](num_partial_results > 1 ? index_without_linear
4203 : input_index)
4204 .ValueOrDie();
4205 Store(input_ir_value, input_address);
4206 llvm::Value* partial_result_address = InBoundsGEP(
4207 partial_reduction_result_address, {b_.getInt32(partial_result_index)});
4208 TF_CHECK_OK(EmitCallToNestedComputation(
4209 *reducers[i], {partial_result_address, input_address},
4210 partial_result_address));
4211 }
4212
4213 // Emit code to generate the output for the non-reduction instructions in the
4214 // fusion, if any.
4215 TF_CHECK_OK(EmitExtraOutputsForReduce(
4216 result_ir_arrays, input_index,
4217 /*use_linear_index=*/num_partial_results == 1, extra_output_gens));
4218 }
4219
4220 llvm::Value* IrEmitterUnnested::EmitThreadId(int64_t threads_per_block,
4221 llvm::Type* index_ty) {
4222   // Calculate the (y, x) coordinates of the thread in the 2D view of the
4223   // thread block, defined by (num_thread_y, num_thread_x), from thread_id.
4224 llvm::CallInst* thread_id_raw = gpu::EmitCallToTargetIntrinsic(
4225 gpu::TargetIntrinsicID::kThreadIdx, {}, {}, &b_);
4226 llvm_ir::AddRangeMetadata(0, threads_per_block, thread_id_raw);
4227 return b_.CreateIntCast(thread_id_raw, index_ty,
4228 /*isSigned=*/true, "thread.id.x");
4229 }
4230
4231 IrEmitterUnnested::ThreadIdInfo IrEmitterUnnested::EmitThreadIdInfo(
4232 int64_t threads_per_block, llvm::Type* index_ty, int64_t num_threads_x) {
4233 auto constant = [&](uint64 c) -> llvm::Constant* {
4234 return llvm::ConstantInt::get(index_ty, c);
4235 };
4236 llvm::Value* thread_id = EmitThreadId(threads_per_block, index_ty);
4237 llvm::Value* num_threads_x_v = constant(num_threads_x);
4238 return {
4239 /*thread_id=*/thread_id,
4240 /*thread_id_x=*/b_.CreateURem(thread_id, num_threads_x_v, "thread_id.x"),
4241 /*thread_id_y=*/b_.CreateUDiv(thread_id, num_threads_x_v, "thread_id.y"),
4242 /*lane_id=*/b_.CreateURem(thread_id, constant(kWarpSize), "lane_id")};
4243 }
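// Illustration (editor's note): the decomposition above is simply
//   thread_id_x = thread_id % num_threads_x;
//   thread_id_y = thread_id / num_threads_x;
//   lane_id     = thread_id % kWarpSize;    // kWarpSize == 32 on NVIDIA GPUs
// with all values computed in the kernel's index type.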
4244
4245 IrEmitterUnnested::TilingKernelInfo IrEmitterUnnested::EmitTilingKernel(
4246 const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty,
4247 const TileElementGenerator& tile_element_generator) {
4248 absl::Span<const int64> dims_in_elems = mapping_scheme.GetDimsInElems();
4249 std::vector<int64> dims_in_blocks = {
4250 CeilOfRatio(dims_in_elems[0], mapping_scheme.GetTileSizeZ()),
4251 CeilOfRatio(dims_in_elems[1], mapping_scheme.GetTileSizeY()),
4252 CeilOfRatio(dims_in_elems[2], mapping_scheme.GetTileSizeX())};
4253 auto constant = [&](uint64 c) -> llvm::Constant* {
4254 return llvm::ConstantInt::get(index_ty, c);
4255 };
4256
4257 IrEmitterUnnested::ThreadIdInfo thread_id_info =
4258 EmitThreadIdInfo(mapping_scheme.GetThreadsPerBlock(), index_ty,
4259 mapping_scheme.GetNumThreadsX());
4260
4261 KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
4262
4263 const IrArray::Index block_coords = [&] {
4264 llvm::Value* block_id = EmitBlockId();
4265 llvm_ir::AddRangeMetadata(0, mapping_scheme.GetNumberOfBlocks(),
4266 llvm::cast<llvm::Instruction>(block_id));
4267 llvm::Value* linear_block_id =
4268 b_.CreateIntCast(block_id, index_ty, /*isSigned=*/true, "block.id.x");
4269 IrArray::Index starting_block(linear_block_id,
4270 ShapeUtil::MakeShapeWithDescendingLayout(
4271 PRED /*arbitrary*/, dims_in_blocks),
4272 &b_);
4273
4274 std::vector<llvm::Value*> multidim = {
4275 b_.CreateMul(starting_block[0], constant(mapping_scheme.GetTileSizeZ()),
4276 "block_origin.z"),
4277 starting_block[1], starting_block[2]};
4278 return IrArray::Index(multidim, dims_in_blocks, index_ty);
4279 }();
4280
4281 std::array<llvm::Value*, 3> output_tile_bounds;
4282 for (int i = kDimY; i < kDimTot; ++i) {
4283 int64_t tile_size_for_dim = mapping_scheme.GetTileSizeFor(i);
4284     // Only the last row or column may not have full size.
4285 llvm::Value* is_last =
4286 b_.CreateICmpEQ(block_coords[i], constant(dims_in_blocks[i] - 1));
4287 int64_t partial_row =
4288 dims_in_elems[i] - (dims_in_blocks[i] - 1) * tile_size_for_dim;
4289 output_tile_bounds[i] =
4290 b_.CreateSelect(is_last, constant(partial_row),
4291 constant(tile_size_for_dim), "tile_bound");
4292 }
4293
4294 IrArray::Index tile_origin = [&] {
4295 std::vector<llvm::Value*> elem_multi_index = block_coords.multidim();
4296 llvm::Type* index_ty = block_coords.GetType();
4297 for (int i = kDimY; i < kDimTot; ++i) {
4298 elem_multi_index[i] = b_.CreateMul(
4299 block_coords[i],
4300 llvm::ConstantInt::get(index_ty, mapping_scheme.GetTileSizeFor(i)),
4301 "tile_origin." + std::to_string(i));
4302 }
4303 return IrArray::Index(elem_multi_index, mapping_scheme.GetDimsInElems(),
4304 index_ty);
4305 }();
4306
4307 auto emit_tile = [&](const IrArray::Index& tile) {
4308 tile_element_generator(thread_id_info, tile, "output",
4309 output_tile_bounds[1], output_tile_bounds[2], &ksl);
4310 };
4311
4312 if (mapping_scheme.GetTileSizeZ() == 1) {
4313 emit_tile(tile_origin);
4314 } else {
4315 llvm::Value* starting_tile_index_for_dim = tile_origin[kDimZ];
4316 llvm::Value* block_size_for_dim = constant(mapping_scheme.GetTileSizeZ());
4317 llvm::Value* block_id_for_dim =
4318 b_.CreateUDiv(starting_tile_index_for_dim, block_size_for_dim);
4319 llvm::Value* last_block_for_dim = constant(dims_in_blocks[kDimZ] - 1);
4320 llvm::Value* last_block_size_for_dim =
4321 constant(dims_in_elems[kDimZ] -
4322 (dims_in_blocks[kDimZ] - 1) * mapping_scheme.GetTileSizeZ());
4323
4324 llvm::Value* num_tiles_in_block =
4325 b_.CreateSelect(b_.CreateICmpEQ(last_block_for_dim, block_id_for_dim),
4326 last_block_size_for_dim, block_size_for_dim);
4327 ksl.For("loop_z",
4328 /*start=*/constant(0),
4329 /*end=*/num_tiles_in_block,
4330 /*step=*/1, [&](llvm::Value* block_dim_induction_var) {
4331 IrArray::Index tile_index = tile_origin.AddOffsetToDim(
4332 block_dim_induction_var, kDimZ, &b_);
4333 emit_tile(tile_index);
4334 });
4335 }
4336 return {output_tile_bounds, tile_origin};
4337 }
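// Worked example (editor's note) for the partial-tile bound computed above:
// with dims_in_elems[i] == 100 and a tile size of 32, dims_in_blocks[i] ==
// CeilOfRatio(100, 32) == 4, so the first three tiles get bound 32 and the
// last one gets 100 - 3 * 32 == 4.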
4338
4339 llvm::CallInst* IrEmitterUnnested::EmitSyncThreads() {
4340 return EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_);
4341 }
4342
4343 // Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose
4344 // algorithm to improve the memory access patterns for the input parameters
4345 // with a shape that is a 0-2-1 transpose of the output tensor shape. The caller
4346 // is responsible for making sure that it is safe to apply the shared memory
4347 // transpose on the input parameters.
4348 //
4349 //
4350 // For the purpose of tiling, the output tensors have a logical shape of three
4351 // components 0-2-1 while the relevant input parameters have a logical shape
4352 // of three components 0-1-2 in the order major to minor. The x- and y-
4353 // dimensions of the tensors are tiled in square tiles with an edge length
4354 // `kTileSize`. Each thread block of `kTileSize` x `kNumRows` threads
4355 // transposes one tile: each thread copies kTileSize/kNumRows elements from
4356 // the input to a shared memory tile, then the otherwise "regular HLO kernel"
4357 // reads from the shared memory instead of the original input.
4358 //
4359 // This is similar to the following CUDA algorithm in TensorFlow:
4360 // https://goo.gl/MStRV6.
4361 //
4362 // `kTileSize` should usually be the same as the warp size. We currently choose 32
4363 // for `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`.
4364 //
4365 // TODO(b/33320379): Here each block transposes 1 tile. It may be more
4366 // efficient to launch fewer blocks so each transposes many tiles.
4367 void IrEmitterUnnested::EmitHlo021Tile(
4368 mlir::Operation* op, Thunk* kernel_thunk, const MlirEmitterContext& context,
4369 absl::Span<const llvm_ir::IrArray> operand_arrays,
4370 absl::Span<const llvm_ir::IrArray> output_arrays,
4371 absl::Span<const int64> reduced_output_dims,
4372 absl::Span<const int64> tiled_param_ids,
4373 const KernelMappingScheme& mapping_scheme,
4374 const LaunchDimensions& launch_dimensions) {
4375 std::string name = mlir::GetNameFromLoc(op->getLoc());
4376
4377 llvm::Type* index_type =
4378 GetIndexTypeForKernel(op, launch_dimensions.launch_bound(), &b_);
4379 std::vector<IrArray> param_arrays;
4380
4381 // For each tiled parameter, cast its input IrArray to the corresponding
4382 // reduced shape and keep the reduced shape live during IR emission.
4383 std::vector<IrArray> param_in_reduced_shape_arrays;
4384 std::vector<llvm::Value*> param_shmem_buffers(context.operand_shapes.size(),
4385 nullptr);
4386
4387 auto get_shared_memory_buffer = [&](llvm::Type* elem_ty,
4388 absl::string_view buffer_name) {
4389     // For NVIDIA GPUs, the warp size is 32 threads and shared memory is
4390     // organized into 32 banks. We usually use the warp size, or a multiple of
4391     // the warp size, as the tile size. This can cause all elements in the same
4392     // column of a tile to map to the same memory bank and therefore cause
4393     // shared memory bank conflicts. Adding 1 to the minor dimension of the
4394     // shared memory buffer reduces such bank conflicts.
4395 llvm::Type* buffer_type = llvm::ArrayType::get(
4396 llvm::ArrayType::get(elem_ty, mapping_scheme.GetTileSizeX() + 1),
4397 mapping_scheme.GetTileSizeY());
4398 return llvm_ir::AllocateSharedMemoryTile(b_.GetInsertBlock()->getModule(),
4399 buffer_type, buffer_name);
4400 };
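  // Example (editor's note): with the kWarpSize x kWarpSize tiles chosen in
  // CheckAndEmitHloWithTile021 and f32 elements, the buffer above has LLVM type
  // [32 x [33 x float]]; the extra column shifts consecutive rows by one bank
  // so that column-wise reads hit 32 distinct banks.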
4401
4402 for (int64_t id = 0; id < context.operand_shapes.size(); id++) {
4403 const Shape& param_shape = context.operand_shapes[id];
4404 param_arrays.push_back(operand_arrays[id]);
4405
4406 if (absl::c_linear_search(tiled_param_ids, id)) {
4407 param_shmem_buffers[id] = get_shared_memory_buffer(
4408 llvm_ir::PrimitiveTypeToIrType(param_shape.element_type(), module_),
4409 IrName(name, StrCat("tile", id)));
4410 VLOG(3) << "Added shmem buffer for parameter " << id << ": "
4411 << llvm_ir::DumpToString(*param_shmem_buffers[id]);
4412 Shape reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
4413 param_shape.element_type(), Permute(reduced_output_dims, {0, 2, 1}));
4414 param_in_reduced_shape_arrays.push_back(
4415 param_arrays[id].CastToShape(reduced_shape, &b_));
4416 } else {
4417 param_in_reduced_shape_arrays.push_back(IrArray());
4418 }
4419 }
4420
4421 EmitElementFunction element_generator =
4422 [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
4423 llvm::Value* x_loc, int64_t x_iter_num) {
4424 auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op);
4425 EmitTileElementForFusion(fusion, operand_arrays, output_arrays, index,
4426 mapping_scheme, y_loc, x_loc,
4427 param_shmem_buffers);
4428 };
4429
4430 TileElementGenerator tile_generator =
4431 [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
4432 const string& loop_name, llvm::Value* tile_height,
4433 llvm::Value* tile_width, KernelSupportLibrary* ksl) {
4434 // If shared memory transpose is needed, wait for all threads to reach
4435 // this point, lest we copy a value from tile to output before the other
4436 // thread copies it from input to tile. This is `__syncthreads` in CUDA.
4437 if (!tiled_param_ids.empty()) {
4438 // Calculate the input tile origin from the output tile origin.
4439 const IrArray::Index input_tile_origin(
4440 Permute(index.multidim(), {0, 2, 1}),
4441 Permute(index.dims(), {0, 2, 1}), index.GetType());
4442
4443 // Copy input parameter values to shared memory buffers:
4444 // tile[thread_id_y, thread_id_x] = input[index]
4445 // Note that tile_width and tile_height are flipped here because we
4446 // are reading a transposed tile.
4447 EmitTile(mapping_scheme, input_tile_origin, "input", ksl,
4448 thread_id_info, tile_width, tile_height,
4449 [&](const IrArray::Index& index, llvm::Value* y_loc,
4450 llvm::Value* x_loc, int64_t /*x_iter_num*/) {
4451 for (int64_t id : tiled_param_ids) {
4452 IrArray& input_in_logical_shape =
4453 param_in_reduced_shape_arrays.at(id);
4454
4455 llvm::Value* shmem_buffer = param_shmem_buffers.at(id);
4456 llvm::Value* zero =
4457 llvm::ConstantInt::get(index_type, 0);
4458 // TODO(jlebar): Add AA metadata to this store. Tile
4459 // buffers are global variables, so LLVM can't infer much
4460 // about it.
4461 auto value = input_in_logical_shape.EmitReadArrayElement(
4462 index, &b_, "input_element");
4463 auto addr = GEP(shmem_buffer, {zero, y_loc, x_loc});
4464 Store(value, addr);
4465 }
4466 });
4467
4468 // Wait for all threads to reach this point using `__syncthreads` in
4469 // CUDA.
4470 EmitSyncThreads();
4471 }
4472
4473 EmitTile(mapping_scheme, index, loop_name, ksl, thread_id_info,
4474 tile_height, tile_width, element_generator);
4475 bool block_contains_multi_tiles = mapping_scheme.GetTileSizeZ() > 1;
4476
4477 // If a tile block contains multiple tiles and shared memory buffers are
4478 // used, we need to wait for all threads to finish using the shared
4479 // memory buffer for the current tile before we move on to process the
4480 // next tile and overwrite the shared memory buffers.
4481 if (block_contains_multi_tiles && !tiled_param_ids.empty()) {
4482 EmitSyncThreads();
4483 }
4484 };
4485
4486 EmitTilingKernel(mapping_scheme, index_type, tile_generator);
4487 }
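// Illustration (editor's sketch of the tiled-transpose idea, not emitter
// output; a simplified CUDA analogue with one element per thread):
//   __shared__ float tile[kTileSize][kTileSize + 1];
//   tile[ty][tx] = in[y0 + ty][x0 + tx];
//   __syncthreads();
//   out[x0 + ty][y0 + tx] = tile[tx][ty];
// Reading the shared-memory tile transposed (and padding it by one column)
// keeps both the global loads and the global stores coalesced.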
4488
4489 namespace {
4490
4491 // A recursive function to inspect the users of a parameter to determine
4492 // whether it's safe for a parameter to participate in a shared-memory
4493 // transpose.
4494 //
4495 // Consider a fusion parameter P for which we might want to use a shmem
4496 // transpose. If we do, we use a GPU thread block to preload a tile of P with
4497 // indices [z, y..y+31, x..x+31] to compute an output tile with the same indices
4498 // cooperatively, where z, y, x are the indices for the normalized input/output
4499 // tensor (see the document for FindTranspose021 for the definition of
4500 // normalized tensor for 0-2-1 transpose). This shmem transpose implementation
4501 // requires that the computation of the output tile only read elements within
4502 // the preload tile. If this is not true, we can't use a shmem transpose for P.
4503 //
4504 // If the computation of output element [z, y, x] only requires the element of
4505 // P with the same indices, the shmem transpose implementation can be applied
4506 // to P safely. This is a sufficient but not necessary condition. We check all
4507 // the transitive users of P to see if we can find a user that may cause an
4508 // exception to the situation. If such a user is not found, we conclude that P
4509 // is safe for shmem transpose.
4510 //
4511 // This is trivially true for elementwise operations and some "data-movement"
4512 // ops like kTuple. However, it's not true for operations that can change the
4513 // dimensions of the inputs (e.g. pad, slice) and for the bitcast operation.
4514 // For example:
4515 //
4516 // fused_computation {
4517 // param_0 = f32[64,64]{1,0} parameter(0)
4518 // ROOT bitcast = f32[64,64]{0,1} bitcast(param_0)
4519 // }
4520 // The output element at logical address [0, 63] depends on the input element
4521 // at logical address [63, 0], which would not be within the shared-memory
4522 // block.
4523 //
4524 // TODO(bixia): In order to extend this for kInput fusion, that is reduction
4525 // with transpose, we only need to end the use-chain checking with the input of
4526 // a reduce operation. In this case, the above description of "output" applies
4527 // to the result of such a use-chain, which provides the input to the reduce
4528 // operation.
4529 bool IsInstructionSafeForShmemTranspose(mlir::Operation* op) {
4530 if (mlir::isa<mlir::memref::TensorStoreOp>(op)) {
4531 return true;
4532 }
4533
4534 HloOpcode opcode;
4535 if (mlir::isa<mlir::memref::TensorLoadOp>(op)) {
4536 opcode = HloOpcode::kParameter;
4537 } else {
4538 opcode = *MhloToHloOpcode(op);
4539 }
4540 if (HloInstruction::IsOpElementwise(opcode)) {
4541 for (mlir::Value v : op->getResults()) {
4542 for (mlir::OpOperand use : v.getUsers()) {
4543 if (!IsInstructionSafeForShmemTranspose(use.getOwner())) {
4544 return false;
4545 }
4546 }
4547 }
4548 return true;
4549 }
4550
4551 switch (opcode) {
4552 // Non-elementwise instructions that don't cause the shmem transpose
4553 // to be unsafe, including the instructions that don't currently fuse.
4554 case HloOpcode::kGetDimensionSize:
4555 // The result of the operation doesn't rely on the content of the
4556 // tensor. As such, there is no need to further inspect its users.
4557 return true;
4558 case HloOpcode::kGetTupleElement:
4559 case HloOpcode::kMap:
4560 case HloOpcode::kParameter:
4561 case HloOpcode::kTuple:
4562 case HloOpcode::kTupleSelect:
4563 for (mlir::Value v : op->getResults()) {
4564 for (mlir::OpOperand use : v.getUsers()) {
4565 if (!IsInstructionSafeForShmemTranspose(use.getOwner())) {
4566 return false;
4567 }
4568 }
4569 }
4570 return true;
4571
4572 default:
4573 return false;
4574 }
4575 }
4576
4577 // Given a group of input parameters that are 0-2-1 transpose of the outputs of
4578 // a fusion kernel, returns the input parameters that are safe for the shared
4579 // memory transpose implementation.
4580 //
4581 // When a tile based shared memory transpose is used to implement an input with
4582 // 0-2-1 transpose, we preload a tile of the input elements
4583 // [z, y..y+31, x..x+31] to compute the output tile elements of the same
4584 // indices. Preloading the input tile this way is only safe when the computation
4585 // of the output tile elements does not need any input element outside the
4586 // preloaded tile. We inspect all the transitive users of the input parameter
4587 // up to the fusion root instruction to see if we can find any instruction
4588 // that can make preloading the input tile unsafe.
4589 std::vector<int64> FilterInputsForShmemTranspose(mlir::lmhlo::FusionOp fusion,
4590 std::vector<int64> input_ids) {
4591 std::vector<mlir::Value> params = ToStdVector(fusion.getFusionParameters());
4592
4593 std::vector<int64> filtered_input_ids;
4594 for (int64_t input_id : input_ids) {
4595 mlir::Value input = params.at(input_id);
4596 if (IsInstructionSafeForShmemTranspose(input.getDefiningOp())) {
4597 filtered_input_ids.push_back(input_id);
4598 }
4599 }
4600 return filtered_input_ids;
4601 }
4602
4603 } // namespace
4604
4605 StatusOr<bool> IrEmitterUnnested::CheckAndEmitHloWithTile021(
4606 mlir::Operation* op) {
4607 CHECK(mlir::isa<mlir::lmhlo::FusionOp>(op));
4608
4609 MlirEmitterContext context;
4610 context.SetOperation(op);
4611
4612 // If the output_shape is reduced to 021 shape, find all the parameters of
4613 // the HLO that are in the corresponding 012 shape.
4614 std::vector<int64> params_012;
4615 optional<std::vector<int64>> reduced_dims_021;
4616 for (int64_t operand_idx = 0; operand_idx < context.operand_shapes.size();
4617 ++operand_idx) {
4618 const Shape& operand_shape = context.operand_shapes[operand_idx];
4619 auto find_transpose_result =
4620 ShapeUtil::FindTranspose021(operand_shape, context.output_shapes[0]);
4621 if (!find_transpose_result.has_value()) {
4622 continue;
4623 }
4624 const std::vector<int64>& curr_reduced_dims_021 = *find_transpose_result;
4625 if (!reduced_dims_021.has_value()) {
4626 reduced_dims_021 = curr_reduced_dims_021;
4627 }
4628 if (!absl::c_equal(*reduced_dims_021, curr_reduced_dims_021)) {
4629 // There is more than one possible transpose. Instead of picking one
4630 // transpose, we simply give up here.
4631 return false;
4632 }
4633 params_012.push_back(operand_idx);
4634 }
4635
4636 if (!reduced_dims_021.has_value()) {
4637 return false;
4638 }
4639
4640 if ((*reduced_dims_021)[1] < kMinDimensionToTransposeTiled ||
4641 (*reduced_dims_021)[2] < kMinDimensionToTransposeTiled) {
4642 return false;
4643 }
4644
4645 if (auto fusion_op = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
4646 params_012 = FilterInputsForShmemTranspose(fusion_op, params_012);
4647 if (params_012.empty()) {
4648 return false;
4649 }
4650 }
4651
4652 // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the
4653 // elements are of size 4 bytes), and CUDA has an architectural limit of
4654 // 48kb shared memory per SM. (This is increased to 96kb in Volta, but we
4655 // don't use this, in part because it eats into our L1 cache space.)
4656 //
4657 // For correctness we need to ensure that we don't make more than 48kb worth
4658 // of shmem tiles per block. And for performance, we'd probably like to use
4659 // significantly less, so that we can fit more than one block at a time on a
4660 // gpu core.
4661 //
4662   // We say without benchmarks that we want at least 3 blocks per core,
4663 // corresponding to 3 shmem tiles if the elements are 32 bits wide. We
4664 // choose which params get the shmem transpose treatment arbitrarily; it's
4665 // not clear if there's a Right Choice.
4666 //
4667 // This is only sound if tiled transposes are the only place where we use
4668 // shared memory in fusions. If in the future other fusible ops use shared
4669 // memory, we'll have to adjust this heuristic.
4670 constexpr int kMinBlocksPerCore = 3;
4671 constexpr int64_t kShmemPerCore = 48 * 1024;
4672 int64_t shmem_used = 0;
4673 for (int64_t i = 0; i < params_012.size(); ++i) {
4674 const Shape& operand_shape = context.operand_shapes[params_012[i]];
4675 shmem_used +=
4676 32 * 33 *
4677 ShapeUtil::ByteSizeOfPrimitiveType(operand_shape.element_type());
4678
4679 if (kMinBlocksPerCore * shmem_used > kShmemPerCore) {
4680 // Erase this element and everything after it from params_012.
4681 params_012.resize(i);
4682 break;
4683 }
4684 }
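  // Worked example (editor's note): a 32x33 f32 tile occupies 32 * 33 * 4 =
  // 4224 bytes, so the loop above admits at most three f32 parameters; a
  // fourth would give 3 * 4 * 4224 = 50688 bytes > 48KiB.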
4685
4686 if (params_012.empty()) {
4687 return false;
4688 }
4689
4690 constexpr int kNumRows = 4;
4691 KernelMappingScheme mapping_scheme(*reduced_dims_021,
4692 /*tile_sizes=*/{1, kWarpSize, kWarpSize},
4693 /*num_threads_y=*/kNumRows,
4694 /*num_threads_x=*/kWarpSize,
4695 /*indexing_order=*/kLinearIndexingX,
4696 /*vector_size=*/1,
4697 /*is_row_contiguous=*/false);
4698 LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
4699 mapping_scheme.GetThreadsPerBlock());
4700 std::vector<llvm_ir::IrArray> ir_arrays;
4701 TF_ASSIGN_OR_RETURN(
4702 std::unique_ptr<KernelThunk> kernel_thunk,
4703 BuildKernelThunk(op, GetThunkInfo(op), &ir_arrays, launch_dimensions));
4704
4705 EmitHlo021Tile(
4706 op, kernel_thunk.get(), context,
4707 absl::MakeSpan(ir_arrays).subspan(0, context.operand_shapes.size()),
4708 absl::MakeSpan(ir_arrays).subspan(context.operand_shapes.size()),
4709 *reduced_dims_021, params_012, mapping_scheme, launch_dimensions);
4710 AddThunkToThunkSequence(std::move(kernel_thunk));
4711 return true;
4712 }
4713
4714 namespace {
4715
4716 // Returns true if all the transitive users of hlo before hitting users in
4717 // use_chain_endings are elementwise operations.
4718 bool AreUsersElementwise(
4719 mlir::Value value,
4720 const absl::flat_hash_set<mlir::Operation*>& use_chain_endings) {
4721 return absl::c_all_of(value.getUsers(), [&](mlir::OpOperand use) {
4722 mlir::Operation* user = use.getOwner();
4723 CHECK_EQ(1, user->getNumResults());
4724 return use_chain_endings.count(user) ||
4725 (HloInstruction::IsOpElementwise(*MhloToHloOpcode(user)) &&
4726 AreUsersElementwise(user->getResult(0), use_chain_endings));
4727 });
4728 }
4729
4730 // Returns the number of fusion inputs that have the same dimensions as the
4731 // given shape and are involved in only elementwise operations.
4732 int64 NumInputsInvolveInOnlyElementwiseOps(
4733 mlir::lmhlo::FusionOp fusion, const Shape& op_shape,
4734 const absl::flat_hash_set<mlir::Operation*>& use_chain_endings) {
4735 return absl::c_count_if(
4736 fusion.getFusionParameters(), [&](mlir::Value parameter) {
4737 Shape parameter_shape = GetShape(parameter);
4738 return ShapeUtil::SameDimensions(op_shape, parameter_shape) &&
4739 AreUsersElementwise(parameter, use_chain_endings);
4740 });
4741 }
4742
4743 // Returns the number of fusion inputs that have more elements than the given
4744 // shape.
4745 int64 NumInputsWithMoreElementsThan(mlir::lmhlo::FusionOp fusion,
4746 const Shape& shape) {
4747 int64_t num_elements = ShapeUtil::ElementsIn(shape);
4748 return absl::c_count_if(
4749 fusion.getFusionParameters(), [&](mlir::Value parameter) {
4750 Shape parameter_shape = GetShape(parameter);
4751 return ShapeUtil::ElementsIn(parameter_shape) > num_elements;
4752 });
4753 }
4754
4755 // The benefit of unrolling a kInput fusion that is a column reduction comes
4756 // from the vectorization of non-reduction fusion outputs and fusion inputs.
4757 // On the other hand, unrolling can also introduce factors that can cause
4758 // the kernel to run slower. This routine uses a simple heuristic to estimate
4759 // the benefit as well as the overhead of unrolling in order to decide whether
4760 // unrolling is beneficial for the given kInput fusion.
4761 bool IsUnrollingColumnReductionBeneficial(
4762 mlir::Operation* unnested_hlo, const Shape& input_shape,
4763 int64_t num_kept_minor, const FusionLayoutAnalysis& layout_analysis) {
4764   // TODO(b/122468062): Need to investigate further to see whether we can
4765 // remove the constraint on IsPowerOfTwo.
4766 if (!IsPowerOfTwo(static_cast<uint64>(num_kept_minor))) {
4767 return false;
4768 }
4769
4770 if (IsReductionFromOrToContiguousDimensions(unnested_hlo, layout_analysis)) {
4771 return true;
4772 }
4773
4774 auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(unnested_hlo);
4775 int64_t can_be_vectorized = 0;
4776 int64_t cannot_be_vectorized = 0;
4777 auto fusion_results = ToStdVector(fusion.getFusionResults());
4778 absl::flat_hash_set<mlir::Operation*> use_chain_endings;
4779 if (fusion_results.size() == 1) {
4780 if (IsReductionFromOrToContiguousDimensions(
4781 fusion_results[0].getDefiningOp(), layout_analysis)) {
4782 use_chain_endings.insert(fusion_results[0].getDefiningOp());
4783 // Atomic.add of the reduction result can't be vectorized.
4784 cannot_be_vectorized++;
4785 }
4786 } else {
4787 for (mlir::Value result : fusion_results) {
4788 if (IsReductionFromOrToContiguousDimensions(result.getDefiningOp(),
4789 layout_analysis)) {
4790 // Atomic.add of the reduction result can't be vectorized.
4791 cannot_be_vectorized++;
4792 } else {
4793 // Write of the non-reduction result can be vectorized.
4794 can_be_vectorized++;
4795 }
4796 use_chain_endings.insert(result.getDefiningOp());
4797 }
4798 }
4799   // Fusion inputs that have the same dimensions as the reduce input and are
4800   // only involved in elementwise operations can be vectorized.
4801 can_be_vectorized += NumInputsInvolveInOnlyElementwiseOps(fusion, input_shape,
4802 use_chain_endings);
4803 // Fusion inputs with more elements than the reduce op input must participate
4804 // in non-elementwise operations and we assume that they are not vectorizable
4805 // for the purpose of estimating the benefit of unrolling. If the kernel is
4806 // unrolled even with such an assumption, and the accesses to those inputs
4807 // turn out to be vectorizable, the compiler will still vectorize them.
4808 cannot_be_vectorized += NumInputsWithMoreElementsThan(fusion, input_shape);
4809 return can_be_vectorized >= cannot_be_vectorized;
4810 }
4811
4812 int64 NearestPowerOfTwo(int64_t v) {
4813 if (v < 0) {
4814 return 0;
4815 }
4816 int64_t upper = tensorflow::NextPowerOfTwo64(v);
4817 int64_t lower = upper >> 1;
4818 return upper - v < v - lower ? upper : lower;
4819 }
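// Example (editor's note, assuming NextPowerOfTwo64 rounds up to the next
// power of two): NearestPowerOfTwo(7) == 8, while NearestPowerOfTwo(6) == 4
// because the strict comparison breaks ties toward the lower power of two.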
4820
4821 } // namespace
4822
4823 ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo(
4824 mlir::Operation* unnested_hlo, mlir::Operation* first_reduce,
4825 const FusionLayoutAnalysis& layout_analysis) {
4826 Shape input_shape = layout_analysis.GetShape(first_reduce->getOperand(0));
4827 ReductionDimensions reduction_dimensions =
4828 GetReductionKindAndContiguousComponents(first_reduce);
4829 VLOG(10) << "is_row_reduction " << reduction_dimensions.is_row_reduction
4830 << " " << reduction_dimensions.dimensions[0] << " "
4831 << reduction_dimensions.dimensions[1] << " "
4832 << reduction_dimensions.dimensions[2];
4833 auto get_dtype_bits = [](mlir::Value i) {
4834 // TODO(timshen): may not be efficient.
4835 return primitive_util::BitWidth(GetShape(i).element_type());
4836 };
4837
4838 // For fusion with multiple inputs, use the smallest input dtype to
4839 // select the reduction_tiling.
4840 int smallest_input_dtype_bits = get_dtype_bits(first_reduce->getOperand(0));
4841
4842 for (mlir::Value operand : GetHloOperands(unnested_hlo)) {
4843 smallest_input_dtype_bits =
4844 std::min(get_dtype_bits(operand), smallest_input_dtype_bits);
4845 }
4846 std::array<int64, 3> reduction_tiling =
4847 GetReductionTiling(reduction_dimensions, smallest_input_dtype_bits,
4848 ir_emitter_context_->cuda_compute_capability());
4849
4850 int64_t num_threads_y = reduction_dimensions.is_row_reduction ? 1 : kWarpSize;
4851 int64_t num_threads_x = [&] {
4852 if (reduction_dimensions.is_row_reduction) {
4853 // Use 512 as default block size (threads per block) for row reductions.
4854 // For multi-output fusions, reduce the block size further to decrease
4855 // register pressure when multiple outputs are computed by each thread.
4856 int64_t fan_out = 1;
4857 if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo)) {
4858 fan_out = fusion.getFusionResults().size();
4859 }
4860
4861 int64_t max_block_size =
4862 std::max(kMinThreadsXRowReduction,
4863 static_cast<int64_t>(512LL / NearestPowerOfTwo(fan_out)));
4864 return std::min(
4865 max_block_size,
4866 RoundUpToNearest(CeilOfRatio(reduction_dimensions.dimensions[2],
4867 reduction_tiling[2]),
4868 kWarpSize));
4869 }
4870 return kWarpSize;
4871 }();
4872
4873 bool tile_fit = reduction_dimensions.dimensions[kDimX] %
4874 (reduction_tiling[2] * num_threads_x) ==
4875 0;
4876 se::CudaComputeCapability cc = ir_emitter_context_->cuda_compute_capability();
4877
4878 int num_partial_results = 1;
4879 KernelMappingScheme::IndexingOrder indexing_order = [&]() {
4880 if (reduction_dimensions.is_row_reduction &&
4881         // On P100, only try to vectorize+coalesce memory access when the
4882 // tile size fits exactly and dtypes <= 32 bits
4883 ((cc.major == 6 && smallest_input_dtype_bits <= 32 && tile_fit) ||
4884         // On V100, only try to vectorize+coalesce memory access for
4885 // rows of even size. For odd row sizes, every other row
4886 // isn't aligned, so it can't be vectorized.
4887 (cc.major >= 7 && reduction_dimensions.dimensions[2] % 2 == 0))) {
4888 return kStridedLinearIndexingX;
4889 } else if (!reduction_dimensions.is_row_reduction &&
4890 IsUnrollingColumnReductionBeneficial(
4891 unnested_hlo, input_shape,
4892 reduction_dimensions.dimensions[2], layout_analysis)) {
4893 num_partial_results = 2;
4894 reduction_tiling[2] *= num_partial_results;
4895 return kLinearIndexingX;
4896 } else {
4897 return kStridedIndexingX;
4898 }
4899 }();
4900
4901 int vector_size = 1;
4902 if (indexing_order == kStridedLinearIndexingX) {
4903 // Assuming XLA will perform the unrolling and LLVM will vectorize,
4904 // disable the unroll for the cases that LLVM doesn't vectorize.
4905 if (reduction_dimensions.dimensions[2] % 2 == 0 &&
4906 !MayPreventVectorization(unnested_hlo)) {
4907 vector_size = 2;
4908 } else {
4909 indexing_order = kStridedIndexingX;
4910 }
4911 }
4912 KernelMappingScheme mapping_scheme(
4913 reduction_dimensions.dimensions,
4914 {reduction_tiling[0], reduction_tiling[1] * num_threads_y,
4915 reduction_tiling[2] * num_threads_x},
4916 num_threads_y, num_threads_x, indexing_order, vector_size);
4917 return ReductionCodegenInfo(
4918 mapping_scheme, num_partial_results,
4919 reduction_dimensions.is_row_reduction,
4920 ReductionIsRaceFree(reduction_dimensions, reduction_tiling));
4921 }
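// Illustration (editor's note): for a row reduction the block size along x
// chosen above is
//   num_threads_x = min(max_block_size,
//                       RoundUpToNearest(CeilOfRatio(dims[2], tiling[2]),
//                                        kWarpSize)),
// i.e. just enough warps to cover one tile row, capped so that multi-output
// fusions with large fan-out get smaller blocks and less register pressure.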
4922
4923 void IrEmitterUnnested::EmitIRForReduction(
4924 mlir::Operation* unnested_hlo, absl::Span<const int> instr_index_group,
4925 HloComputation* fused_computation, FusedIrEmitter* fused_emitter,
4926 absl::Span<const llvm_ir::IrArray> result_ir_arrays,
4927 ReductionCodegenState* reduction_info, const Shape& input_shape,
4928 const FusionLayoutAnalysis& layout_analysis) {
4929 std::vector<HloComputation*> reducers;
4930 for (int index : instr_index_group) {
4931 mlir::Operation* reduce = GetReduceFromUnnestedMlir(unnested_hlo, index);
4932 if (!IsReductionFromOrToContiguousDimensions(reduce, layout_analysis)) {
4933 continue;
4934 }
4935 if (auto nested_reduce = mlir::dyn_cast<mlir::mhlo::ReduceOp>(reduce)) {
4936 HloInstruction* root = fused_computation->root_instruction();
4937 if (root->opcode() == HloOpcode::kTuple) {
4938 root = root->mutable_operand(index);
4939 } else {
4940 CHECK_EQ(0, index);
4941 }
4942 reducers.push_back(root->to_apply());
4943 } else {
4944 LOG(FATAL) << "Unexpected reduce op: " << MlirToString(reduce);
4945 }
4946 }
4947   CHECK(!reducers.empty()) << " expect at least one reduce instruction.";
4948
4949 const KernelMappingScheme& mapping_scheme =
4950 reduction_info->GetKernelMappingScheme();
4951 LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
4952 mapping_scheme.GetThreadsPerBlock());
4953 llvm::Type* index_ty = GetIndexTypeForKernel(
4954 unnested_hlo, launch_dimensions.launch_bound(), &b_);
4955 EmitPrologueForReduction(unnested_hlo, instr_index_group, fused_computation,
4956 fused_emitter, result_ir_arrays, reduction_info,
4957 layout_analysis);
4958
4959 EmitElementFunction emit_reduction_tile =
4960 [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
4961 llvm::Value* x_loc, int64_t x_iter_num) {
4962 EmitTileElementForReduction(
4963 unnested_hlo, input_shape, instr_index_group, fused_computation,
4964 fused_emitter, result_ir_arrays, reducers, index, *reduction_info,
4965 x_iter_num, layout_analysis);
4966 };
4967
4968 TilingKernelInfo tiling_kernel_info = EmitTilingKernel(
4969 mapping_scheme, index_ty,
4970 [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
4971 const string& loop_name, llvm::Value* tile_height,
4972 llvm::Value* tile_width, KernelSupportLibrary* ksl) {
4973 EmitTile(reduction_info->GetKernelMappingScheme(), index, loop_name,
4974 ksl, thread_id_info, tile_height, tile_width,
4975 emit_reduction_tile);
4976 });
4977 EmitEpilogueForReduction(index_ty, unnested_hlo, instr_index_group,
4978 result_ir_arrays, reducers, *reduction_info,
4979 tiling_kernel_info, layout_analysis);
4980 }
4981
4982 namespace {
4983
4984 // Returns whether the `instr` is either a constant, a scalar, or a
4985 // broadcasted constant/scalar.
4986 bool IsBroadcastedConstantOrScalar(const HloInstruction& instr) {
4987 return instr.IsConstant() || ShapeUtil::IsScalar(instr.shape()) ||
4988 (HloOpcode::kBroadcast == instr.opcode() &&
4989 (instr.operand(0)->IsConstant() ||
4990 ShapeUtil::IsScalar(instr.operand(0)->shape())));
4991 }
4992
4993 // Divides `num_reduces` reduces into groups. Different groups will be executed
4994 // in parallel. Generally speaking, we'd like to run the reduce instructions
4995 // in parallel without incurring too much recomputation overhead. The current
4996 // heuristic is to place reduce instructions that share nothing or only
4997 // (broadcasted) scalars/constants into different groups; otherwise, they are
4998 // placed in the same group. Non-reduce instructions always go with the reduce
4999 // instructions into the same group so long as they share any predecessors.
5000 std::vector<std::vector<int>> GroupDisjointReductions(
5001 HloComputation* fused_computation, int num_reduces) {
5002 CHECK_NE(0, num_reduces);
5003 if (num_reduces == 1) {
5004 return {{0}};
5005 }
5006
5007 std::vector<tensorflow::UnionFind<HloInstruction*>> disjoint_sets(
5008 num_reduces);
5009 for (size_t i = 0; i < num_reduces; ++i) {
5010 disjoint_sets[i].Get() =
5011 fused_computation->root_instruction()->mutable_operand(i);
5012 }
5013
5014 std::unique_ptr<HloReachabilityMap> reachability_map =
5015 HloReachabilityMap::Build(fused_computation);
5016 for (HloInstruction* instr : fused_computation->instructions()) {
5017 std::vector<int64> reached_output_ids;
5018 for (size_t oid = 0; oid < num_reduces; ++oid) {
5019 HloInstruction* reduce =
5020 fused_computation->root_instruction()->mutable_operand(oid);
5021 if (HloOpcode::kReduce == reduce->opcode() &&
5022 (IsBroadcastedConstantOrScalar(*instr))) {
5023 // Do not group output reduce instructions through broadcasted
5024 // constants or scalars, as the recomputation should be acceptable.
5025 VLOG(3) << "Skip broadcasted constant or scalar " << instr->ToString();
5026 continue;
5027 }
5028 // Now group output instructions if they have common predecessors.
5029 if (reachability_map->IsReachable(instr, reduce)) {
5030 VLOG(3) << "Reaching " << reduce->ToString() << " from "
5031 << instr->ToString();
5032 reached_output_ids.push_back(oid);
5033 }
5034 }
5035 for (size_t j = 1; j < reached_output_ids.size(); ++j) {
5036 disjoint_sets[reached_output_ids[0]].Merge(
5037 &disjoint_sets[reached_output_ids[j]]);
5038 }
5039 }
5040 // Place output instructions in the same set into the same group.
5041 HloInstructionMap<std::vector<int>> groups;
5042 for (size_t oid = 0; oid < num_reduces; ++oid) {
5043 groups[disjoint_sets[oid].Get()].push_back(oid);
5044 }
5045
5046 std::vector<std::vector<int>> ret;
5047 absl::c_for_each(
5048 groups, [&](auto& iter) { ret.emplace_back(std::move(iter.second)); });
5049 return ret;
5050 }
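// Example (editor's note): for a fusion whose root is tuple(r0, r1, r2) where
// r0 and r1 both read a non-scalar producer and r2 shares only a broadcasted
// constant with them, the union-find above produces the groups {{0, 1}, {2}}:
// r0 and r1 share real work, while r2 can run in parallel in its own group.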
5051
5052 } // namespace
5053
5054 Status IrEmitterUnnested::EmitUnnestedReduction(
5055 mlir::Operation* unnested_hlo,
5056 const FusionLayoutAnalysis& layout_analysis) {
5057 auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(unnested_hlo);
5058
5059 int num_reduces = fusion.getFusionResults().size();
5060
5061 // Build a kernel thunk to compute all the outputs.
5062 mlir::Operation* first_reduce = nullptr;
5063 for (int i = 0; i < num_reduces; ++i) {
5064 mlir::Operation* output_instruction =
5065 GetReduceFromUnnestedMlir(unnested_hlo, i);
5066 if (IsReductionFromOrToContiguousDimensions(output_instruction,
5067 layout_analysis)) {
5068 first_reduce = GetReduceFromUnnestedMlir(unnested_hlo, i);
5069 break;
5070 }
5071 }
5072 CHECK(first_reduce) << MlirToString(unnested_hlo);
5073 if (num_reduces > 1) {
5074 for (int i = 0; i < num_reduces; i++) {
5075 auto candidate = mlir::dyn_cast<mlir::mhlo::ReduceOp>(
5076 GetReduceFromUnnestedMlir(unnested_hlo, i));
5077 if (candidate &&
5078 !IsFusedReductionOutputConsistent(
5079 candidate, mlir::cast<mlir::mhlo::ReduceOp>(first_reduce),
5080 layout_analysis)) {
5081 return InternalError("Inconsistent reduction fusion outputs");
5082 }
5083 }
5084 }
5085 Shape input_shape = GetShape(first_reduce->getOperand(0));
5086 // The layout of a reduction input is either set by LayoutAssignment for
5087 // unnested kReduce or by InstructionFusion for fused kReduce.
5088 CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion "
5089 "doesn't set the input layout of "
5090 << MlirToString(first_reduce);
5091
5092 HloComputation* fused_computation = nullptr;
5093 TF_ASSIGN_OR_RETURN(fused_computation,
5094 GetOrCreateSubComputationFromRegion(&fusion.region(),
5095 /*is_fusion=*/true));
5096
5097 // Group disjoint reductions in groups, to be executed in parallel.
5098 std::vector<std::vector<int>> instr_index_groups =
5099 GroupDisjointReductions(fused_computation, num_reduces);
5100
5101 VLOG(2) << StrCat("Generate in ", instr_index_groups.size(), " groups for ",
5102 MlirToString(unnested_hlo));
5103
5104 ReductionCodegenInfo reduction_info =
5105 ComputeReductionCodegenInfo(unnested_hlo, first_reduce, layout_analysis);
5106 const KernelMappingScheme& mapping_scheme =
5107 reduction_info.GetKernelMappingScheme();
5108
5109 // block_y_count is set to instr_index_groups.size(), so that each reduction
5110 // group can be run in parallel by a different BlockIdy.
5111 LaunchDimensions launch_dimensions(
5112 {/*x=*/mapping_scheme.GetNumberOfBlocks(),
5113 /*y=*/static_cast<int64>(instr_index_groups.size()),
5114 /*z=*/1},
5115 {/*x=*/mapping_scheme.GetThreadsPerBlock(), /*y=*/1, /*z=*/1});
5116 VLOG(3) << "Launch dimensions of "
5117 << mlir::GetNameFromLoc(unnested_hlo->getLoc())
5118 << ": number of blocks: " << mapping_scheme.GetNumberOfBlocks()
5119 << " - threads per block: " << mapping_scheme.GetThreadsPerBlock();
5120
5121 std::vector<llvm_ir::IrArray> ir_arrays;
5122 TF_ASSIGN_OR_RETURN(std::unique_ptr<KernelThunk> kernel_thunk,
5123 BuildKernelThunk(unnested_hlo, Thunk::ThunkInfo(),
5124 &ir_arrays, launch_dimensions));
5125
5126 GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
5127 ir_emitter_context_->llvm_module(),
5128 &b_, GetNestedComputer());
5129 FusedIrEmitter fused_emitter(&elemental_emitter);
5130
5131 CHECK_LT(fused_computation->num_parameters(), ir_arrays.size());
5132 for (int i = 0; i < fused_computation->num_parameters(); i++) {
5133 llvm_ir::IrArray ir_array = ir_arrays[i];
5134 HloInstruction* fused_operand = fused_computation->parameter_instruction(i);
5135 fused_emitter.BindGenerator(
5136 fused_operand,
5137 [this, ir_array, fused_operand](const llvm_ir::IrArray::Index& index) {
5138 return ir_array.EmitReadArrayElement(index, &b_,
5139 fused_operand->name());
5140 });
5141 }
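  // The output arrays follow the parameter arrays in ir_arrays: one IrArray
  // per fusion result.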
5142 absl::Span<const llvm_ir::IrArray> result_ir_arrays =
5143 absl::MakeSpan(ir_arrays).subspan(fused_computation->num_parameters(),
5144 num_reduces);
5145
5146 // We always use the first reduce as representative to construct
5147 // ReductionCodegenInfo, since all the reductions are required to have the
5148 // same shape and layout as verified by `IsFusedReductionOutputConsistent()`.
5149 ReductionCodegenInfo reduction_codegen_info =
5150 ComputeReductionCodegenInfo(unnested_hlo, first_reduce, layout_analysis);
5151
5152 KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
5153 for (size_t i = 0; i < instr_index_groups.size(); ++i) {
5154     // Create a new ReductionCodegenState instance, as it carries state that is
5155     // specific to code generation for each reduction group.
5156 ReductionCodegenState reduction_info =
5157 ReductionCodegenState(reduction_codegen_info);
5158 // Use raw block_id_y to select the i-th parallel reduction to run. Using
5159 // block_id_y instead of block_id_x simplifies the index calculation
5160     // for reduction code generation, because block_id_y is orthogonal to
5161     // the indices used within the reductions.
5162 llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic(
5163 gpu::TargetIntrinsicID::kBlockIdy, {}, {}, &b_);
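    // Attach range metadata [0, instr_index_groups.size()) to the block id so
    // LLVM knows which group indices are possible.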
5164 llvm_ir::AddRangeMetadata(0, instr_index_groups.size(),
5165 llvm::cast<llvm::Instruction>(raw_block_id_y));
5166 ksl.If(StrCat("reduce-group-", i),
5167 b_.CreateICmpEQ(raw_block_id_y, b_.getInt32(i)), [&] {
5168 EmitIRForReduction(unnested_hlo, instr_index_groups[i],
5169 fused_computation, &fused_emitter,
5170 result_ir_arrays, &reduction_info, input_shape,
5171 layout_analysis);
5172 });
5173 }
5174
5175 if (hlo_module_config_.debug_options().xla_gpu_deterministic_reductions() &&
5176 !reduction_codegen_info.IsRaceFree()) {
5177 return InternalError(
5178 "All reductions should be race-free if deterministic reductions are "
5179 "enabled");
5180 }
5181
5182 // Build an initializer thunk to initialize each reduction output.
5183 ThunkSequence thunks;
5184 for (int i = 0; i < num_reduces; ++i) {
5185 mlir::Operation* output_instruction =
5186 GetReduceFromUnnestedMlir(unnested_hlo, i);
5187 if (!IsReductionFromOrToContiguousDimensions(output_instruction,
5188 layout_analysis)) {
5189       // Handled by the elemental IR emitter; no initializer thunk is needed.
5190 continue;
5191 } else if (reduction_codegen_info.IsRaceFree()) {
5192 VLOG(5) << "We do not need initialization: using tree reductions";
5193 continue;
5194 }
5195
5196 TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
5197 BuildFusedInitializerThunk(fusion, i));
5198 thunks.push_back(std::move(initializer_thunk));
5199 }
5200
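  // The initializer thunks must run before the reduction kernel, so the kernel
  // thunk is appended last and the whole sequence is wrapped in a single
  // SequentialThunk.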
5201 thunks.push_back(std::move(kernel_thunk));
5202 auto sequential_thunk = absl::make_unique<SequentialThunk>(
5203 GetThunkInfo(unnested_hlo), std::move(thunks));
5204 AddThunkToThunkSequence(std::move(sequential_thunk));
5205
5206 return Status::OK();
5207 }
5208
5209 // Emits code for slices based on the structure below. An if statement with
5210 // a guarding condition is generated for each ROOT slice.
5211 //
5212 // Pseudo code:
5213 //
5214 // Compute values of slice input operands
5215 //
5216 // Compute guarding_cond0
5217 // if (guarding_cond0) {
5218 // Write to output of slice0
5219 // }
5220 //
5221 // Compute guarding_cond1
5222 // if (guarding_cond1) {
5223 // Write to output of slice1
5224 // }
5225 //
5226 Status IrEmitterUnnested::EmitElementForInputFusibleSlices(
5227 const HloComputation* fused_computation,
5228 absl::Span<const llvm_ir::IrArray> ir_arrays,
5229 const llvm_ir::IrArray::Index& index) {
5230 VLOG(10) << "Emitting slice input fusion for "
5231 << fused_computation->ToString();
5232
5233 HloInstruction* slice_or_tuple = fused_computation->root_instruction();
5234 auto slice_instructions = [&]() -> absl::Span<HloInstruction* const> {
5235 if (slice_or_tuple->opcode() == HloOpcode::kSlice) {
5236 return absl::Span<HloInstruction* const>(&slice_or_tuple, 1);
5237 }
5238 CHECK_EQ(slice_or_tuple->opcode(), HloOpcode::kTuple);
5239 return slice_or_tuple->operands();
5240 }();
5241
5242 // Emit input operand values of slices.
5243 std::vector<llvm::Value*> input_ir_values;
5244 GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
5245 GetNestedComputer());
5246 FusedIrEmitter fused_emitter(&elem_emitter);
5247 for (int i = 0; i < fused_computation->num_parameters(); i++) {
5248 fused_emitter.BindGenerator(
5249 fused_computation->parameter_instruction(i),
5250 [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
5251 return ir_arrays[i].EmitReadArrayElement(index, &b_);
5252 });
5253 }
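  // The slice inputs are computed once here, outside the per-slice guards, and
  // the resulting values are reused by the guarded writes below.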
5254 for (const HloInstruction* slice : slice_instructions) {
5255 auto input_generator = *fused_emitter.GetGenerator(slice->operand(0));
5256 input_ir_values.push_back(input_generator(index).ValueOrDie());
5257 }
5258
5259   // Emit a guarded write for each slice instruction.
5260 KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
5261 for (int64_t i = 0; i < slice_instructions.size(); ++i) {
5262 HloInstruction* slice = slice_instructions[i];
5263
5264 // guarding_cond := index >= start && index < limit, for each dim.
5265 std::vector<llvm::Value*> index_within_ranges;
5266 for (size_t dim = 0; dim < slice->slice_starts().size(); ++dim) {
5267 CHECK_EQ(slice->slice_strides(dim), 1);
5268 auto larger_or_equal_than_start = b_.CreateICmpSGE(
5269 index.multidim()[dim],
5270 index.GetConstantWithIndexType(slice->slice_starts(dim)));
5271 llvm::Value* smaller_than_limit = b_.CreateICmpSLT(
5272 index.multidim()[dim],
5273 index.GetConstantWithIndexType(slice->slice_limits(dim)));
5274 llvm::Value* within_range =
5275 b_.CreateAnd(larger_or_equal_than_start, smaller_than_limit);
5276 index_within_ranges.push_back(within_range);
5277 }
5278 llvm::Value* guarding_cond = b_.CreateAnd(index_within_ranges);
5279
5280 auto emit_slice_elem_func = [&] {
5281 const std::vector<llvm::Value*>& src_multidim = index.multidim();
5282 std::vector<llvm::Value*> dst_multidim(src_multidim.size());
5283 for (size_t dim = 0; dim < src_multidim.size(); ++dim) {
5284 dst_multidim[dim] =
5285 Sub(src_multidim[dim],
5286 index.GetConstantWithIndexType(slice->slice_starts(dim)));
5287 }
5288 llvm_ir::IrArray src_ir_array =
5289 ir_arrays[fused_computation->num_parameters() + i];
5290 IrArray::Index slice_dst_index(dst_multidim, slice->shape(),
5291 index.GetType());
5292 src_ir_array.EmitWriteArrayElement(slice_dst_index, input_ir_values[i],
5293 &b_);
5294 };
5295
5296 ksl.If(StrCat("slice", i), guarding_cond, emit_slice_elem_func);
5297 }
5298 return Status::OK();
5299 }
5300
5301 Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices(
5302 mlir::Operation* op) {
5303 auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(op);
5304
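  // Each thread handles exactly one element per loop iteration; no unrolling
  // is applied to slice input fusions here.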
5305 constexpr int unroll_factor = 1;
5306
5307 TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation,
5308 GetOrCreateSubComputationFromRegion(&fusion.region(),
5309 /*is_fusion=*/true));
5310
5311 TF_ASSIGN_OR_RETURN(Shape element_shape,
5312 GetConsistentInputShapeForRootSlices(fused_computation));
5313 TF_ASSIGN_OR_RETURN(LaunchDimensions launch_dimensions,
5314 CalculateLaunchDimensions(
5315 element_shape, ir_emitter_context_->gpu_device_info(),
5316 {unroll_factor}));
5317
5318 std::vector<llvm_ir::IrArray> ir_arrays;
5319 TF_ASSIGN_OR_RETURN(auto kernel_thunk,
5320 BuildKernelThunk(fusion, GetThunkInfo(op), &ir_arrays,
5321 launch_dimensions));
5322
5323 Status emit_status =
5324 ParallelLoopEmitter(
5325 [&](const llvm_ir::IrArray::Index index) -> Status {
5326 return EmitElementForInputFusibleSlices(fused_computation,
5327 ir_arrays, index);
5328 },
5329 element_shape, launch_dimensions, &b_)
5330 .EmitLoop(IrName(mlir::GetNameFromLoc(fusion.getLoc())),
5331 GetIndexTypeForKernel(
5332 fusion, launch_dimensions.launch_bound(), &b_));
5333
5334 thunk_sequence_.emplace_back(std::move(kernel_thunk));
5335
5336 return emit_status;
5337 }
5338
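// Dispatches a single LMHLO operation to the matching Emit* routine. Ops that
// only describe buffers or block structure (constants, memref views and casts,
// returns, terminators) produce no thunks and are skipped up front.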
5339 Status IrEmitterUnnested::EmitOp(mlir::Operation* op) {
5340 if (mlir::isa<mlir::ConstantOp, mlir::memref::ViewOp,
5341 mlir::memref::ReinterpretCastOp, mlir::ReturnOp,
5342 mlir::lmhlo::TerminatorOp>(op)) {
5343 return Status::OK();
5344 }
5345
5346 if (mlir::isa<mlir::memref::GetGlobalOp>(op)) {
5347 return EmitConstant(op);
5348 }
5349
5350 if (auto call = mlir::dyn_cast<mlir::lmhlo::CustomCallOp>(op)) {
5351 if (call.call_target_name() == "PadToStatic") {
5352 return EmitPadToStatic(op);
5353 }
5354 if (call.call_target_name() == "SliceToDynamic") {
5355 return EmitSliceToDynamic(op);
5356 }
5357 return EmitCustomCallThunk(op);
5358 }
5359
5360 if (mlir::isa<mlir::lmhlo_gpu::GEMMOp, mlir::lmhlo_gpu::GEMM_BiasOp>(op)) {
5361 return EmitGemmThunk(op);
5362 }
5363
5364 if (mlir::isa<mlir::lmhlo_gpu::ConvForwardOp,
5365 mlir::lmhlo_gpu::ConvForwardFusedOp,
5366 mlir::lmhlo_gpu::ConvForwardFusedSideInputOp,
5367 mlir::lmhlo_gpu::ConvBackwardFilterOp,
5368 mlir::lmhlo_gpu::ConvBackwardInputOp>(op)) {
5369 return EmitConvolutionThunk(op);
5370 }
5371
5372 if (mlir::isa<mlir::lmhlo_gpu::BatchNormTrainingOp,
5373 mlir::lmhlo_gpu::BatchNormInferenceOp,
5374 mlir::lmhlo_gpu::BatchNormGradOp>(op)) {
5375 return EmitBatchNormThunk(op);
5376 }
5377
5378 #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
5379 if (mlir::isa<mlir::lmhlo_gpu::CholeskyOp>(op)) {
5380 return EmitCholeskyThunk(op);
5381 }
5382 #endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
5383
5384 if (mlir::isa<mlir::lmhlo::FftOp>(op)) {
5385 return EmitFftThunk(op);
5386 }
5387
5388 if (mlir::isa<mlir::lmhlo::TriangularSolveOp>(op)) {
5389 return EmitTriangularSolve(op);
5390 }
5391
5392 if (mlir::isa<mlir::lmhlo::FusionOp>(op)) {
5393 return EmitFusion(op);
5394 }
5395
5396 if (mlir::isa<mlir::lmhlo::SelectAndScatterOp>(op)) {
5397 return EmitSelectAndScatter(op);
5398 }
5399
5400 if (mlir::isa<mlir::lmhlo::RngGetAndUpdateStateOp>(op)) {
5401 return EmitRngGetAndUpdateState(op);
5402 }
5403
5404 if (mlir::isa<mlir::lmhlo::ScatterOp>(op)) {
5405 return EmitScatter(op);
5406 }
5407
5408 if (mlir::isa<mlir::lmhlo::SortOp>(op)) {
5409 return EmitSort(op);
5410 }
5411
5412 if (mlir::isa<mlir::lmhlo::ReplicaIdOp>(op)) {
5413 return EmitReplicaOrPartitionId<ReplicaIdThunk, mlir::lmhlo::ReplicaIdOp>(
5414 op);
5415 }
5416
5417 if (mlir::isa<mlir::lmhlo::PartitionIdOp>(op)) {
5418 return EmitReplicaOrPartitionId<PartitionIdThunk,
5419 mlir::lmhlo::PartitionIdOp>(op);
5420 }
5421
5422 if (mlir::isa<mlir::lmhlo::CollectivePermuteOp>(op)) {
5423 return EmitCollectivePermute(op);
5424 }
5425
5426 if (mlir::isa<mlir::lmhlo::AllGatherOp>(op)) {
5427 return EmitNcclThunk<NcclAllGatherThunk, mlir::lmhlo::AllGatherOp>(op);
5428 }
5429
5430 if (mlir::isa<mlir::lmhlo::AllReduceOp>(op)) {
5431 return EmitNcclThunk<NcclAllReduceThunk, mlir::lmhlo::AllReduceOp>(op);
5432 }
5433
5434 if (mlir::isa<mlir::lmhlo_gpu::AllReduceStartOp>(op)) {
5435 return EmitNcclThunk<NcclAllReduceStartThunk,
5436 mlir::lmhlo_gpu::AllReduceStartOp>(op);
5437 }
5438
5439 if (mlir::isa<mlir::lmhlo_gpu::AllReduceDoneOp>(op)) {
5440 return EmitAllReduceDone(op);
5441 }
5442
5443 if (mlir::isa<mlir::lmhlo::ReduceScatterOp>(op)) {
5444 return EmitNcclThunk<NcclReduceScatterThunk, mlir::lmhlo::ReduceScatterOp>(
5445 op);
5446 }
5447
5448 if (mlir::isa<mlir::lmhlo::AllToAllOp>(op)) {
5449 return EmitNcclThunk<NcclAllToAllThunk, mlir::lmhlo::AllToAllOp>(op);
5450 }
5451
5452 if (mlir::isa<mlir::lmhlo::InfeedOp>(op)) {
5453 return EmitInfeed(op);
5454 }
5455
5456 if (mlir::isa<mlir::lmhlo::OutfeedOp>(op)) {
5457 return EmitOutfeed(op);
5458 }
5459
5460 if (mlir::isa<mlir::lmhlo::CaseOp>(op)) {
5461 return EmitConditional(op);
5462 }
5463
5464 if (mlir::isa<mlir::lmhlo::WhileOp>(op)) {
5465 return EmitWhile(op);
5466 }
5467
5468 return InternalError("Unrecognized op: %s", MlirToString(op));
5469 }
5470
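// Emits thunks for every op in the region's entry block, in program order. The
// early-increment range keeps iteration valid even if an emitter erases the op
// being visited.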
5471 Status IrEmitterUnnested::EmitLmhloRegion(mlir::Region* region) {
5472 for (mlir::Operation& op : llvm::make_early_inc_range(region->front())) {
5473 TF_RETURN_IF_ERROR(EmitOp(&op));
5474 }
5475 return Status::OK();
5476 }
5477
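// The profile annotation combines the HLO op name and the module name so that
// profiler traces can be attributed back to the originating HLO.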
5478 Thunk::ThunkInfo IrEmitterUnnested::GetThunkInfo(mlir::Operation* op) {
5479 auto module = op->getParentOfType<mlir::ModuleOp>();
5480 Thunk::ThunkInfo thunk_info;
5481 thunk_info.profile_annotation = absl::StrFormat(
5482 "Thunk:#hlo_op=%s,hlo_module=%s#", mlir::GetNameFromLoc(op->getLoc()),
5483 mlir::GetNameFromLoc(module->getLoc()));
5484 return thunk_info;
5485 }
5486
5487 void MlirEmitterContext::SetOperation(mlir::Operation* op) {
5488 this->name = mlir::GetNameFromLoc(op->getLoc());
5489
5490 auto operands = GetHloOperands(op);
5491 auto outputs = GetHloOutputs(op);
5492 for (auto operand : operands) {
5493 operand_shapes.push_back(GetShape(operand));
5494 }
5495 for (auto output : outputs) {
5496 output_shapes.push_back(GetShape(output));
5497 }
5498 }
5499
5500 } // namespace gpu
5501 } // namespace xla
5502