1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h"
17
18 #include <algorithm>
19 #include <cstring>
20 #include <iterator>
21 #include <memory>
22 #include <string>
23 #include <type_traits>
24 #include <vector>
25
26 #include "absl/algorithm/container.h"
27 #include "absl/container/inlined_vector.h"
28 #include "absl/memory/memory.h"
29 #include "absl/strings/str_cat.h"
30 #include "absl/strings/str_format.h"
31 #include "absl/types/optional.h"
32 #include "absl/types/span.h"
33 #include "llvm/ADT/APInt.h"
34 #include "llvm/ADT/SetVector.h"
35 #include "llvm/ADT/StringRef.h"
36 #include "llvm/IR/BasicBlock.h"
37 #include "llvm/IR/Function.h"
38 #include "llvm/IR/IRBuilder.h"
39 #include "llvm/IR/Instructions.h"
40 #include "llvm/IR/LLVMContext.h"
41 #include "llvm/IR/Module.h"
42 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
43 #include "mlir/IR/Attributes.h" // from @llvm-project
44 #include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project
45 #include "mlir/IR/Builders.h" // from @llvm-project
46 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
47 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
48 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
49 #include "mlir/IR/Verifier.h" // from @llvm-project
50 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_gpu_ops.h"
51 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
52 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/hlo_utils.h"
53 #include "tensorflow/compiler/mlir/utils/name_utils.h"
54 #include "tensorflow/compiler/mlir/xla/attribute_exporter.h"
55 #include "tensorflow/compiler/mlir/xla/hlo_function_importer.h"
56 #include "tensorflow/compiler/mlir/xla/hlo_utils.h"
57 #include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
58 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
59 #include "tensorflow/compiler/xla/layout_util.h"
60 #include "tensorflow/compiler/xla/literal.h"
61 #include "tensorflow/compiler/xla/permutation_util.h"
62 #include "tensorflow/compiler/xla/service/buffer_assignment.h"
63 #include "tensorflow/compiler/xla/service/custom_call_target_registry.h"
64 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
65 #include "tensorflow/compiler/xla/service/gpu/backend_configs.pb.h"
66 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
67 #include "tensorflow/compiler/xla/service/gpu/collective_permute_thunk.h"
68 #include "tensorflow/compiler/xla/service/gpu/conditional_thunk.h"
69 #include "tensorflow/compiler/xla/service/gpu/convolution_thunk.h"
70 #include "tensorflow/compiler/xla/service/gpu/copy_thunk.h"
71 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_thunk.h"
72 #include "tensorflow/compiler/xla/service/gpu/custom_call_thunk.h"
73 #include "tensorflow/compiler/xla/service/gpu/fft_thunk.h"
74 #include "tensorflow/compiler/xla/service/gpu/for_thunk.h"
75 #include "tensorflow/compiler/xla/service/gpu/gemm_thunk.h"
76 #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h"
77 #include "tensorflow/compiler/xla/service/gpu/gpu_conv_runner.h"
78 #include "tensorflow/compiler/xla/service/gpu/hlo_to_ir_bindings.h"
79 #include "tensorflow/compiler/xla/service/gpu/infeed_thunk.h"
80 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
81 #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
82 #include "tensorflow/compiler/xla/service/gpu/kernel_mapping_scheme.h"
83 #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
84 #include "tensorflow/compiler/xla/service/gpu/launch_dimensions.h"
85 #include "tensorflow/compiler/xla/service/gpu/memset_thunk.h"
86 #include "tensorflow/compiler/xla/service/gpu/nccl_all_gather_thunk.h"
87 #include "tensorflow/compiler/xla/service/gpu/nccl_all_reduce_thunk.h"
88 #include "tensorflow/compiler/xla/service/gpu/nccl_all_to_all_thunk.h"
89 #include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h"
90 #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
91 #include "tensorflow/compiler/xla/service/gpu/replica_id_thunk.h"
92 #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
93 #include "tensorflow/compiler/xla/service/gpu/target_util.h"
94 #include "tensorflow/compiler/xla/service/gpu/thunk.h"
95 #include "tensorflow/compiler/xla/service/gpu/triangular_solve_thunk.h"
96 #include "tensorflow/compiler/xla/service/gpu/while_thunk.h"
97 #include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
98 #include "tensorflow/compiler/xla/service/hlo_computation.h"
99 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
100 #include "tensorflow/compiler/xla/service/hlo_instructions.h"
101 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
102 #include "tensorflow/compiler/xla/service/llvm_ir/buffer_assignment_util.h"
103 #include "tensorflow/compiler/xla/service/llvm_ir/dynamic_update_slice_util.h"
104 #include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
105 #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
106 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
107 #include "tensorflow/compiler/xla/service/llvm_ir/sort_util.h"
108 #include "tensorflow/compiler/xla/service/llvm_ir/tuple_ops.h"
109 #include "tensorflow/compiler/xla/service/name_uniquer.h"
110 #include "tensorflow/compiler/xla/service/pattern_matcher.h"
111 #include "tensorflow/compiler/xla/service/shape_inference.h"
112 #include "tensorflow/compiler/xla/service/while_loop_analysis.h"
113 #include "tensorflow/compiler/xla/shape_util.h"
114 #include "tensorflow/compiler/xla/status_macros.h"
115 #include "tensorflow/compiler/xla/types.h"
116 #include "tensorflow/compiler/xla/union_find.h"
117 #include "tensorflow/compiler/xla/util.h"
118 #include "tensorflow/compiler/xla/window_util.h"
119 #include "tensorflow/compiler/xla/xla_data.pb.h"
120 #include "tensorflow/core/lib/core/bits.h"
121 #include "tensorflow/core/lib/core/status.h"
122 #include "tensorflow/core/platform/errors.h"
123 #include "tensorflow/core/platform/logging.h"
124
125 #if GOOGLE_CUDA
126 #include "tensorflow/compiler/xla/service/gpu/cholesky_thunk.h"
127 #endif // GOOGLE_CUDA
128
129 namespace xla {
130 namespace gpu {
131
132 namespace {
133
134 using absl::InlinedVector;
135 using absl::nullopt;
136 using absl::optional;
137 using absl::StrCat;
138 using llvm_ir::IrArray;
139 using llvm_ir::IrName;
140
141 const auto kDimX = KernelMappingScheme::DimX;
142 const auto kDimY = KernelMappingScheme::DimY;
143 const auto kDimZ = KernelMappingScheme::DimZ;
144 const auto kDimTot = KernelMappingScheme::DimTot;
145
146 const auto kLinearIndexingX = KernelMappingScheme::LinearIndexingX;
147 const auto kStridedIndexingX = KernelMappingScheme::StridedIndexingX;
148 const auto kStridedLinearIndexingX =
149 KernelMappingScheme::StridedLinearIndexingX;
150
151 // If a dimension is smaller than this, untiled transposition may be more
152 // efficient.
153 const int64 kMinDimensionToTransposeTiled = 16;
154
155 // Updates the launch dimensions in "thunk" and annotates the launch dimensions
156 // of the corresponding IR kernel in "llvm_module".
157 // Precondition: "thunk" must be a KernelThunk.
158 void UpdateLaunchDimensions(const LaunchDimensions& launch_dims, Thunk* thunk,
159 llvm::Module* llvm_module) {
160 CHECK(Thunk::Kind::kKernel == thunk->kind());
161 KernelThunk* kernel_thunk = static_cast<KernelThunk*>(thunk);
162 kernel_thunk->SetLaunchDimensions(launch_dims);
163
164 // Add __launch_bounds__ to metadata. This limits registers per thread to
165 // avoid out-of-resources launching errors.
166 llvm::NamedMDNode* nvvm_annotations_node =
167 llvm_module->getOrInsertNamedMetadata("nvvm.annotations");
168 llvm::Function* ir_kernel =
169 llvm_module->getFunction(kernel_thunk->kernel_name().c_str());
170 llvm::LLVMContext& llvm_context = llvm_module->getContext();
171 llvm::ConstantInt* threads_per_block_ir_value = llvm::ConstantInt::get(
172 llvm::IntegerType::get(llvm_context, /*NumBits=*/32),
173 launch_dims.thread_counts_per_block().x);
174 // Our launch bounds are exact, so we can specify them as reqntidx rather than
175 // maxntidx.
176 nvvm_annotations_node->addOperand(llvm::MDNode::get(
177 llvm_context,
178 {llvm::ConstantAsMetadata::get(ir_kernel),
179 llvm::MDString::get(llvm_context, "reqntidx"),
180 llvm::ConstantAsMetadata::get(threads_per_block_ir_value)}));
181 }
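// For illustration only: for a kernel launched with 128 threads per block, the
// metadata added above would look roughly like (function and node names are
// hypothetical):
//   !nvvm.annotations = !{!0}
//   !0 = !{void (i8*)* @some_kernel, !"reqntidx", i32 128}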
182
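// Returns true if `v` is present in `elements`. The attribute is assumed to be
// sorted in ascending (signed) order, as required by std::binary_search.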
183 bool BinarySearchDenseElementsAttr(mlir::DenseIntElementsAttr elements,
184 int64 v) {
185 mlir::APInt value(sizeof(int64) * 8, v, /*isSigned=*/true);
186 return std::binary_search(
187 elements.begin(), elements.end(), value,
188 [](const mlir::APInt& x, const mlir::APInt& y) { return x.slt(y); });
189 }
190
191 // Returns true if the fusion contains any instruction that is likely
192 // translated to complex LLVM IR, such as loops, and prevents vectorization.
193 bool MayPreventVectorization(const HloInstruction& hlo) {
194 if (hlo.opcode() == HloOpcode::kFusion) {
195 return absl::c_any_of(hlo.fused_instructions_computation()->instructions(),
196 [](const HloInstruction* instr) {
197 switch (instr->opcode()) {
198 case HloOpcode::kReduceWindow:
199 case HloOpcode::kSort:
200 case HloOpcode::kDot:
201 case HloOpcode::kSin:
202 case HloOpcode::kCos:
203 case HloOpcode::kPower:
204 case HloOpcode::kAtan2:
205 return true;
206 case HloOpcode::kReduce:
207 return !instr->shape().IsArray();
208 default:
209 return false;
210 }
211 });
212 } else if (hlo.IsElementwise()) {
213 // Unfused elementwise operations are usually memory bound; unroll them.
214 switch (hlo.opcode()) {
215 // The following elementwise operation implementations contain branches.
216 // LLVM vectorizer doesn't work in that case.
217 // The unrolled code is faster when it isn't vectorized.
218 case HloOpcode::kSin:
219 case HloOpcode::kCos:
220 case HloOpcode::kPower:
221 case HloOpcode::kAtan2:
222 return true;
223 default:
224 return false;
225 }
226 } else if (hlo.opcode() == HloOpcode::kReduce && hlo.shape().IsArray()) {
227 // TODO(timshen): check if the to_apply() attribute contains instructions
228 // that break LLVM vectorization.
229 return false;
230 }
231 return true;
232 }
233
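// Returns true if the LMHLO op maps to an elementwise HLO opcode. A kMap op
// counts as elementwise when its dimensions attribute is the identity mapping
// 0, 1, 2, ....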
234 bool LmhloOpIsElementwise(mlir::Operation* op) {
235 CHECK(op->getDialect() == op->getContext()->getLoadedDialect("lmhlo"));
236 auto opcode = *MhloToHloOpcode(op);
237 if (HloInstruction::IsOpElementwise(opcode)) {
238 return true;
239 }
240 if (opcode == HloOpcode::kMap) {
241 int iota = 0;
242 for (const llvm::APInt& i :
243 mlir::cast<mlir::lmhlo::MapOp>(op).dimensions()) {
244 if (i.getZExtValue() != iota) {
245 return false;
246 }
247 iota++;
248 }
249 return true;
250 }
251 // TODO(timshen): not sure about whether porting
252 // HloFusionInstruction::IsElementwiseImpl() is necessary. HandleFusion()
253 // doesn't use such information.
254 return false;
255 }
256
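// MLIR counterpart of MayPreventVectorization above: returns true if the LMHLO
// op (or any op inside its fusion region) is likely to be translated to
// complex LLVM IR that prevents vectorization.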
257 bool MayPreventVectorization(mlir::Operation* op) {
258 CHECK(op->getDialect() == op->getContext()->getLoadedDialect("lmhlo"));
259 auto opcode = *MhloToHloOpcode(op);
260
261 if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
262 for (mlir::Operation& instr : fusion.region().front()) {
263 if (mlir::isa<mlir::lmhlo::TerminatorOp, mlir::mhlo::ReturnOp,
264 mlir::TensorLoadOp, mlir::TensorStoreOp>(&instr)) {
265 continue;
266 }
267 CHECK(instr.getDialect() == instr.getContext()->getLoadedDialect("mhlo"))
268 << MlirToString(op);
269 switch (*MhloToHloOpcode(&instr)) {
270 case HloOpcode::kReduceWindow:
271 case HloOpcode::kSort:
272 case HloOpcode::kDot:
273 case HloOpcode::kSin:
274 case HloOpcode::kCos:
275 case HloOpcode::kPower:
276 case HloOpcode::kAtan2:
277 return true;
278 case HloOpcode::kReduce:
279 if (instr.getNumResults() > 1) {
280 return true;
281 }
282 break;
283 default:
284 break;
285 }
286 }
287 return false;
288 } else if (LmhloOpIsElementwise(op)) {
289 // Unfused elementwise operations are usually memory bound; unroll them.
290 switch (opcode) {
291 // The following elementwise operation implementations contain branches.
292 // LLVM vectorizer doesn't work in that case.
293 // The unrolled code is faster when it isn't vectorized.
294 case HloOpcode::kSin:
295 case HloOpcode::kCos:
296 case HloOpcode::kPower:
297 case HloOpcode::kAtan2:
298 return true;
299 default:
300 return false;
301 }
302 } else if (opcode == HloOpcode::kReduce && GetHloOutputs(op).size() == 1) {
303 // TODO(timshen): check if the to_apply() attribute contains instructions
304 // that break LLVM vectorization.
305 return false;
306 }
307 return true;
308 }
309
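// Returns the ops that define the fusion results, deduplicated and in the
// order they are first encountered.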
310 std::vector<mlir::Operation*> GetOutputOps(mlir::lmhlo::FusionOp fusion) {
311 llvm::SetVector<mlir::Operation*> ops;
312 for (mlir::Value output_value : fusion.getFusionResults()) {
313 ops.insert(output_value.getDefiningOp());
314 }
315 return std::vector<mlir::Operation*>(ops.begin(), ops.end());
316 }
317
318 // Computes the maximum valid unroll factor for a given shape.
319 int ComputeMaxUnrollFactor(const Shape& shape,
320 const HloModuleConfig& hlo_module_config) {
321 int max_unroll_factor =
322 hlo_module_config.debug_options().xla_gpu_max_kernel_unroll_factor();
323
324 // Find the largest possible power of two to unroll by.
325 // TODO(kramerb): Make this smarter.
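326 // For example, assuming the default xla_gpu_max_kernel_unroll_factor of 4, a
327 // shape with 6 elements is not divisible by 4 but is divisible by 2, so the
328 // loop below picks an unroll factor of 2.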
326 int64 num_elements = ShapeUtil::ElementsIn(shape);
327 for (int i = max_unroll_factor; i > 1; i /= 2) {
328 if (num_elements % i == 0) {
329 return i;
330 }
331 }
332
333 // Cannot unroll.
334 return 1;
335 }
336
337 // Computes the maximum valid unroll factor for a given instruction.
338 int ComputeMaxUnrollFactor(const HloInstruction* hlo) {
339 const Shape& element_shape = hlo->IsMultiOutputFusion()
340 ? ShapeUtil::GetSubshape(hlo->shape(), {0})
341 : hlo->shape();
342 return ComputeMaxUnrollFactor(element_shape, hlo->GetModule()->config());
343 }
344
345 // Computes the maximum valid unroll factor for a given MLIR operation.
346 int ComputeMaxUnrollFactor(mlir::Operation* op,
347 const HloModuleConfig& hlo_module_config) {
348 Shape element_shape = [&] {
349 std::vector<Shape> shapes;
350 // Detect multi-output fusion. Notice that for a reduce in the fusion that
351 // returns a tuple, we don't want to treat it as multi-output fusion. We
352 // want to pass that tuple into ComputeMaxUnrollFactor below. For an actual
353 // MOF, just pass the first element of the root tuple.
354 if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
355 std::vector<mlir::Operation*> fusion_outputs = GetOutputOps(fusion);
356 for (mlir::Value result : fusion_outputs[0]->getResults()) {
357 shapes.push_back(TypeToShape(result.getType()));
358 }
359 } else {
360 for (mlir::Value result : GetHloOutputs(op)) {
361 shapes.push_back(TypeToShape(result.getType()));
362 }
363 }
364 if (shapes.size() > 1) {
365 return ShapeUtil::MakeTupleShape(shapes);
366 }
367 return shapes[0];
368 }();
369 return ComputeMaxUnrollFactor(element_shape, hlo_module_config);
370 }
371
372 // Returns the llvm type for the indices used in the kernel that contains the
373 // hlo instruction. Such indices include the index for the parallel loop and
374 // the indices for the tensors accessed by the kernel. The return type is i32
375 // iff the following conditions are met:
376 // . The launch_size of the kernel is within the range of i32.
377 // . The sizes of all the tensors accessed within the kernel are within the
378 // range of i32.
379 // Otherwise, the return type is i64.
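// For example, an f32[1024, 1024] result (about 1.05e6 elements) fits in i32,
// while an f32[65536, 65536] result (about 4.3e9 elements) exceeds 2^31 - 1
// and forces i64 indices.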
380 llvm::Type* GetIndexTypeForKernel(const HloInstruction* hlo, int64 launch_size,
381 llvm::IRBuilder<>* b) {
382 // Find the unnested hlo instruction for which the kernel is generated.
383 const HloInstruction* unnested_hlo = hlo;
384 const HloComputation* computation = hlo->parent();
385 if (computation->IsFusionComputation()) {
386 unnested_hlo = computation->FusionInstruction();
387 }
388
389 auto shape_in_range = [&](const Shape& s) {
390 bool in_range = true;
391 ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape,
392 const ShapeIndex& /*index*/) {
393 if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) {
394 in_range = false;
395 }
396 });
397
398 return in_range;
399 };
400
401 llvm::Type* i64_ty = b->getInt64Ty();
402 // Check launch dimension
403 if (!IsInt32(launch_size)) {
404 return i64_ty;
405 }
406
407 // Check the size of result tensors
408 if (!shape_in_range(unnested_hlo->shape())) {
409 return i64_ty;
410 }
411
412 auto hlo_shape_in_range = [&](const HloInstruction* operand) -> bool {
413 return shape_in_range(operand->shape());
414 };
415
416 // Check the size of input tensors
417 if (!absl::c_all_of(unnested_hlo->operands(), hlo_shape_in_range)) {
418 return i64_ty;
419 }
420
421 // Check the size of the internal result tensors
422 if (unnested_hlo->opcode() == HloOpcode::kFusion) {
423 if (!absl::c_all_of(
424 unnested_hlo->fused_instructions_computation()->instructions(),
425 hlo_shape_in_range)) {
426 return i64_ty;
427 }
428 }
429
430 return b->getInt32Ty();
431 }
432
433 // The same as GetIndexTypeForKernel, but works with MLIR ops.
434 llvm::Type* GetIndexTypeForKernelFromMlir(mlir::Operation* op,
435 int64 launch_size,
436 llvm::IRBuilder<>* b) {
437 auto shape_in_range = [&](const Shape& s) {
438 bool in_range = true;
439 ShapeUtil::ForEachSubshape(s, [&](const Shape& sub_shape,
440 const ShapeIndex& /*index*/) {
441 if (sub_shape.IsArray() && !IsInt32(ShapeUtil::ElementsIn(sub_shape))) {
442 in_range = false;
443 }
444 });
445
446 return in_range;
447 };
448
449 llvm::Type* i64_ty = b->getInt64Ty();
450 // Check launch dimension
451 if (!IsInt32(launch_size)) {
452 return i64_ty;
453 }
454
455 // Check the size of result tensors
456 for (auto result : GetHloOutputs(op)) {
457 if (!shape_in_range(TypeToShape(result.getType()))) {
458 return i64_ty;
459 }
460 }
461
462 auto hlo_shape_in_range = [&](mlir::Value operand) -> bool {
463 return shape_in_range(TypeToShape(operand.getType()));
464 };
465
466 // Check the size of input tensors
467 if (!absl::c_all_of(op->getOperands(), hlo_shape_in_range)) {
468 return i64_ty;
469 }
470
471 // Check the size of the internal result tensors
472 if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
473 auto result = fusion.region().walk([&](mlir::Operation* op) {
474 for (mlir::Value result : op->getResults()) {
475 if (!hlo_shape_in_range(result)) {
476 return mlir::WalkResult::interrupt();
477 }
478 }
479 return mlir::WalkResult::advance();
480 });
481 if (result.wasInterrupted()) {
482 return i64_ty;
483 }
484 }
485
486 return b->getInt32Ty();
487 }
488
489 // Gets the input shape of the ROOT slices, which will be used as the kernel
490 // launch dims. The slice input fusion requires the input shapes of the ROOT
491 // slices to be the same although the (slice) output shapes can be different.
492 //
493 // Returns the input shape of the ROOT slices if all the input shapes of ROOT
494 // slices are the same and the slices are non-strided. Otherwise, returns
495 // FailedPrecondition.
496 StatusOr<Shape> GetConsistentInputShapeForRootSlices(
497 mlir::lmhlo::FusionOp fusion) {
498 if (!IsInputFusibleSlices(fusion, /*verify_no_strides=*/true)) {
499 return FailedPrecondition(
500 "Unsupported root for slice input fusion. "
501 "Only non-strided slices are supported.");
502 }
503
504 absl::optional<Shape> first_slice_operand_shape;
505 for (mlir::Value result : fusion.getFusionResults()) {
506 auto slice =
507 mlir::dyn_cast_or_null<mlir::mhlo::SliceOp>(result.getDefiningOp());
508 if (!slice) {
509 return FailedPrecondition("Expected a slice op");
510 }
511 if (first_slice_operand_shape.has_value()) {
512 Shape operand_shape = TypeToShape(slice.operand().getType());
513 if (!ShapeUtil::EqualIgnoringElementType(*first_slice_operand_shape,
514 operand_shape)) {
515 return FailedPrecondition(
516 "Fused slices do not have the same input shape, instruction is %s",
517 MlirToString(fusion));
518 }
519 } else {
520 first_slice_operand_shape = TypeToShape(slice.operand().getType());
521 }
522 }
523 if (!first_slice_operand_shape.has_value()) {
524 return InvalidArgument("Fusion has no roots");
525 }
526 return *first_slice_operand_shape;
527 }
528
529 } // namespace
530
531 IrEmitterUnnested::IrEmitterUnnested(const HloModuleConfig& hlo_module_config,
532 const HloComputation* hlo_computation,
533 IrEmitterContext* ir_emitter_context)
534 : IrEmitter(hlo_module_config, ir_emitter_context, /*is_nested=*/false) {}
535
536 StatusOr<std::unique_ptr<IrEmitterUnnested>> IrEmitterUnnested::Create(
537 const HloModuleConfig& hlo_module_config,
538 const HloComputation* hlo_computation,
539 IrEmitterContext* ir_emitter_context) {
540 auto emitter = std::unique_ptr<IrEmitterUnnested>(new IrEmitterUnnested(
541 hlo_module_config, hlo_computation, ir_emitter_context));
542 if (hlo_computation) {
543 emitter->mlir_scratch_module_.emplace(mlir::ModuleOp::create(
544 mlir::Builder(ir_emitter_context->mlir_context()).getUnknownLoc()));
545 emitter->lhlo_scratch_emitter_.emplace(
546 emitter->ir_emitter_context_->buffer_assignment(), *hlo_computation,
547 emitter->mlir_scratch_module_->get());
548 TF_RETURN_IF_ERROR(emitter->lhlo_scratch_emitter_->Initialize());
549 }
550 return std::move(emitter);
551 }
552
553 Status IrEmitterUnnested::Postprocess(HloInstruction* hlo) {
554 bindings_.UnbindAllLocalIrValues();
555 return DfsHloVisitor::Postprocess(hlo);
556 }
557
558 llvm::Function* IrEmitterUnnested::BuildKernelPrototype(
559 absl::string_view name, absl::Span<const BufferAllocation* const> args) {
560 // Compute the kernel name. The opcode string may contain "-" which cannot be
561 // in a PTX function name, so sanitize the name before uniquifying it.
562 string kernel_name = ir_emitter_context_->name_uniquer()->GetUniqueName(
563 llvm_ir::SanitizeFunctionName(std::string(name)));
564
565 // Create the kernel and add it to the module.
566 llvm::Module* module = ir_emitter_context_->llvm_module();
567 llvm::LLVMContext& context = module->getContext();
568 llvm::FunctionType* kernel_type = llvm::FunctionType::get(
569 /*Result=*/llvm::Type::getVoidTy(context),
570 std::vector<llvm::Type*>(args.size(), b_.getInt8PtrTy()),
571 /*isVarArg=*/false);
572 llvm::Function* kernel =
573 llvm::Function::Create(kernel_type, llvm::GlobalValue::ExternalLinkage,
574 kernel_name.c_str(), module);
575
576 // Add dereferenceable and alignment information to each of the kernel's
577 // parameters.
578 auto arg_it = kernel->arg_begin();
579 for (size_t arg_no = 0; arg_no < args.size(); ++arg_no) {
580 const BufferAllocation* alloc = args[arg_no];
581 llvm::Argument* fn_arg = &*arg_it;
582 ++arg_it;
583
584 kernel->addDereferenceableAttr(arg_no + 1, alloc->size());
585
586 const int64 alignment = [&] {
587 if (alloc->is_entry_computation_parameter()) {
588 return kEntryParameterAlignBytes;
589 } else if (alloc->is_constant()) {
590 return kConstantBufferAlignBytes;
591 } else {
592 return kXlaAllocatedBufferAlignBytes;
593 }
594 }();
595
596 kernel->addParamAttr(
597 arg_no,
598 llvm::Attribute::get(context, llvm::Attribute::Alignment, alignment));
599
600 if (alloc->IsPreallocatedTempBuffer()) {
601 fn_arg->setName("temp_buf");
602 } else {
603 fn_arg->setName(StrCat("alloc", alloc->index()));
604 }
605 }
606
607 AnnotateFunctionAsGpuKernel(module, kernel, &b_);
608
609 // TODO(b/65380986): Investigate if adding fast math flags for generated
610 // kernels makes sense.
611
612 // Update the insert point to the entry basic block.
613 llvm::BasicBlock* entry_bb =
614 llvm::BasicBlock::Create(context, /*Name=*/"entry", /*Parent=*/kernel);
615
616 // Emit a "return void" at entry_bb's end, and set the insert point before
617 // that return instruction.
618 b_.SetInsertPoint(llvm::ReturnInst::Create(context, entry_bb));
619
620 return kernel;
621 }
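// As a rough illustration (names, sizes, and alignments below are made up), a
// kernel with two buffer arguments is emitted as something like:
//   define void @fusion_3(i8* align 16 dereferenceable(1024) %alloc0,
//                         i8* align 64 dereferenceable(4096) %temp_buf) {
//   entry:
//     ret void
//   }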
622
623 StatusOr<BufferAllocation::Slice> IrEmitterUnnested::GetAllocationSliceForMlir(
624 mlir::Value v) {
625 return xla::gpu::GetAllocationSliceForMlir(
626 v, ir_emitter_context_->allocations());
627 }
628
629 Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
630 TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
631 return EmitUsingElementalIrEmitter(input);
632 }
633
634 Status IrEmitterUnnested::EmitUsingElementalIrEmitter(MlirEmitterInput input) {
635 // Replace unnested op with a fused nested op.
636 //
637 // TODO(timshen): Ultimately this should be a pass. It's currently not a pass,
638 // because we don't have a fully functioning LMHLO graph yet.
639
640 mlir::Location loc = input.op->getLoc();
641 mlir::lmhlo::FusionOp fusion =
642 mlir::OpBuilder(input.op).create<mlir::lmhlo::FusionOp>(loc);
643 Shape output_shape;
644 mlir::OpBuilder b(&fusion.region());
645
646 const auto load_memrefs = [loc, &b](mlir::ValueRange range) {
647 std::vector<mlir::Value> operands;
648 for (mlir::Value memref : range) {
649 auto load = b.create<mlir::TensorLoadOp>(loc, memref);
650 HloFunctionImporter::SetLayoutForMlir(load,
651 TypeToShape(memref.getType()));
652 operands.push_back(load);
653 }
654 return operands;
655 };
656
657 if (auto copy = mlir::dyn_cast<mlir::lmhlo::CopyOp>(input.op)) {
658 auto operand = b.create<mlir::TensorLoadOp>(loc, copy.operand());
659 HloFunctionImporter::SetLayoutForMlir(
660 operand, TypeToShape(copy.operand().getType()));
661 auto fused_copy = b.create<mlir::mhlo::CopyOp>(loc, operand);
662 output_shape = TypeToShape(copy.output().getType());
663 HloFunctionImporter::SetLayoutForMlir(fused_copy, output_shape);
664 b.create<mlir::TensorStoreOp>(loc, fused_copy, copy.output());
665 } else if (auto reduce = mlir::dyn_cast<mlir::lmhlo::ReduceOp>(input.op)) {
666 std::vector<mlir::Value> operands = load_memrefs(reduce.operands());
667 std::vector<mlir::Value> init_values = load_memrefs(reduce.init_values());
668 auto fused_reduce = b.create<mlir::mhlo::ReduceOp>(
669 loc, operands, init_values, reduce.dimensions());
670 fused_reduce.body().takeBody(reduce.body());
671 CHECK_EQ(fused_reduce.getNumResults(), reduce.out().size());
672 std::vector<Shape> output_shapes;
673 for (int i = 0; i < reduce.out().size(); i++) {
674 b.create<mlir::TensorStoreOp>(loc, fused_reduce.getResult(i),
675 reduce.out()[i]);
676 auto shape = TypeToShape(reduce.out()[i].getType());
677 if (i == 0) {
678 HloFunctionImporter::SetLayoutForMlir(fused_reduce, shape);
679 }
680 output_shapes.push_back(shape);
681 }
682 if (output_shapes.size() == 1) {
683 output_shape = output_shapes[0];
684 } else {
685 output_shape = ShapeUtil::MakeTupleShape(output_shapes);
686 }
687 } else {
688 // Try to generically convert any LMHLO ops to LMHLO fusion + the
689 // corresponding MHLO op. Currently we've only looked at elementwise ops and
690 // they seem to be well covered.
691 //
692 // TODO(timshen): Moving forward, we should make it cover all ops if
693 // possible, and only special-case the ones it can't.
694 std::vector<mlir::Value> outputs;
695 mlir::Operation* new_op;
696 {
697 auto operands = GetHloOperands(input.op);
698 outputs = GetHloOutputs(input.op);
699 TF_RET_CHECK(outputs.size() == 1) << MlirToString(input.op);
700
701 std::vector<mlir::Value> loads = load_memrefs(operands);
702 std::string mhlo_op_name = mlir::hlo::LmhloToMhloOpName(
703 input.op->getName().getStringRef(), input.op->getContext());
704 TF_RET_CHECK(!mhlo_op_name.empty())
705 << "No corresponding MHLO op for given LMHLO op: "
706 << MlirToString(input.op);
707 mlir::OperationState op_state(loc, mhlo_op_name);
708
709 mlir::BlockAndValueMapping mapper;
710 for (mlir::Region& region : input.op->getRegions()) {
711 mlir::Region* new_region = op_state.addRegion();
712 region.cloneInto(new_region, mapper);
713 }
714
715 op_state.addOperands(loads);
716 op_state.addAttributes(input.op->getAttrs());
717 op_state.addTypes({mlir::RankedTensorType::get(
718 outputs[0].getType().cast<mlir::MemRefType>().getShape(),
719 outputs[0].getType().cast<mlir::MemRefType>().getElementType())});
720 new_op = b.createOperation(op_state);
721 }
722 TF_RET_CHECK(mlir::succeeded(mlir::verify(new_op)));
723 output_shape = TypeToShape(outputs[0].getType());
724 HloFunctionImporter::SetLayoutForMlir(new_op, output_shape);
725 b.create<mlir::TensorStoreOp>(loc, new_op->getResult(0), outputs[0]);
726 }
727 int unroll_factor = 1;
728 if (!MayPreventVectorization(input.op)) {
729 unroll_factor = ComputeMaxUnrollFactor(input.op, hlo_module_config_);
730 }
731 input.op->erase();
732 input.op = fusion;
733 return EmitLoopFusionFromMlir(input, output_shape, unroll_factor);
734 }
735
736 Status IrEmitterUnnested::HandleConstant(HloInstruction* constant) {
737 TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(constant));
738 return EmitConstant(input);
739 }
740
741 Status IrEmitterUnnested::EmitConstant(MlirEmitterInput mlir_input) {
742 auto get_global = mlir::cast<mlir::GetGlobalMemrefOp>(mlir_input.op);
743 auto module = get_global->getParentOfType<mlir::ModuleOp>();
744 auto global =
745 mlir::cast<mlir::GlobalMemrefOp>(module.lookupSymbol(get_global.name()));
746
747 auto literal = global.initial_value()->dyn_cast<mlir::DenseElementsAttr>();
748 TF_RET_CHECK(literal);
749
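// Only "small" constants (at most one element) get a real LLVM initializer;
// larger constants are emitted as zeroinitializer here and their payload is
// handed to GpuExecutable via info.content instead.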
750 const bool should_emit_initializer = literal.getType().getNumElements() <= 1;
751
752 TF_ASSIGN_OR_RETURN(int element_bytes,
753 GetElementTypeBytes(literal.getType().getElementType()));
754 llvm::ArrayType* global_type = llvm::ArrayType::get(
755 b_.getInt8Ty(), literal.getType().getNumElements() * element_bytes);
756
757 GpuExecutable::ConstantInfo info;
758 llvm::Constant* initializer;
759 if (should_emit_initializer) {
760 std::vector<uint8> content;
761 TF_RETURN_IF_ERROR(CopyDenseElementsDataToXlaFormat(literal, &content));
762 initializer = llvm::ConstantDataArray::get<uint8>(
763 ir_emitter_context_->llvm_module()->getContext(), content);
764 } else {
765 TF_RETURN_IF_ERROR(
766 CopyDenseElementsDataToXlaFormat(literal, &info.content));
767 initializer = llvm::ConstantAggregateZero::get(global_type);
768 }
769
770 // These globals will be looked up by name by GpuExecutable so we need to
771 // give them an external linkage. Not all of their uses are visible in
772 // the LLVM IR, so we can't give them a linkage that merely preserves their
773 // names (like available_externally); we also need to ensure that they stick
774 // around even if they're "unused".
775 //
776 // We may have to be more clever here in the future if we notice that we're
777 // keeping around too many globals because of their linkage.
778 unsigned global_address_space =
779 llvm_ir::GetGlobalMemoryAddressSpace(*ir_emitter_context_->llvm_module());
780
781 llvm::GlobalVariable* global_for_const = new llvm::GlobalVariable(
782 global_type, /*isConstant=*/should_emit_initializer,
783 llvm::GlobalValue::ExternalLinkage,
784 /*Initializer=*/initializer, global.sym_name(),
785 /*TLMode=*/llvm::GlobalValue::NotThreadLocal,
786 /*AddressSpace=*/global_address_space,
787 /*isExternallyInitialized=*/false);
788 global_for_const->setAlignment(llvm::Align(kConstantBufferAlignBytes));
789 ir_emitter_context_->llvm_module()->getGlobalList().push_back(
790 global_for_const);
791
792 info.symbol_name.assign(global.sym_name().begin(), global.sym_name().end());
793
794 info.allocation_index =
795 global->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc").getInt();
796 ir_emitter_context_->constants().push_back(std::move(info));
797 return Status::OK();
798 }
799
800 Status IrEmitterUnnested::HandleConditional(HloInstruction* conditional) {
801 TF_ASSIGN_OR_RETURN(auto thunk, BuildConditionalThunk(conditional));
802 AddThunkToThunkSequence(std::move(thunk));
803 return Status::OK();
804 }
805
806 Status IrEmitterUnnested::HandleConvolution(HloInstruction* convolution) {
807 AddThunkToThunkSequence(
808 BuildKernelThunk(convolution, /*implements_whole_instruction=*/true));
809 return IrEmitter::HandleConvolution(convolution);
810 }
811
812 // Input = {dynamic array(with dynamic dimension meta data at the end)}
813 // Output = {static array, dynamic_dim0, dynamic_dim1}
814 Status IrEmitterUnnested::EmitPadToStaticFromMlir(MlirEmitterInput mlir_input) {
815 // TODO(jurahul): Create an op to represent PadToStatic.
816 auto pad_to_static = mlir::cast<mlir::lmhlo::CustomCallOp>(mlir_input.op);
817 int unroll_factor = 1;
818 std::string ir_name = mlir::GetNameFromLoc(pad_to_static.getLoc());
819
820 std::vector<llvm_ir::IrArray> ir_arrays;
821 TF_ASSIGN_OR_RETURN(
822 auto kernel_thunk,
823 BuildKernelThunkForMlir(pad_to_static, mlir_input.thunk_info,
824 mlir_input.extra_slice, &ir_arrays));
825
826 const llvm_ir::IrArray source_array = ir_arrays[0];
827 const llvm_ir::IrArray output_array = ir_arrays[1];
828 auto output_dim_arrays =
829 absl::Span<const llvm_ir::IrArray>(ir_arrays).subspan(2);
830
831 // pseudo code for PadToStatic on a 2d array
832 // int* source_array = input[0];
833 // int* dest_array = output[0];
834 const Shape& data_shape =
835 TypeToShape(pad_to_static.output().front().getType());
836 const Shape& input_shape =
837 TypeToShape(pad_to_static.args().front().getType());
838 llvm::Value* source_buffer = source_array.GetBasePointer();
839 llvm::Value* raw_buffer =
840 b_.CreateBitCast(source_buffer, b_.getInt8Ty()->getPointerTo());
841
842 // TODO(jurahul): input_shape here is the static shape of the input (which has
843 // a dynamic shape in XLA). Currently, we are mapping that to a static shaped
844 // memref. When we change that to a more appropriate representation in MLIR,
845 // fix this code to correctly deduce the static shape backing the dynamically
846 // shaped memref.
847 int64 raw_data_size = ShapeUtil::ByteSizeOf(input_shape);
848
849 // int* dyn_dim0_size = source_array + meta_data_offset;
850 // int* dyn_dim1_size = source_array + meta_data_offset + sizeof(int);
851 std::vector<llvm::Value*> dynamic_dims;
852 for (int64 i = 1; i < pad_to_static.output().size(); ++i) {
853 // Dynamic size of each dimension is attached at the end of the source
854 // array (operand(0)). We need to extract these values.
855 const Shape& dim_shape = TypeToShape(pad_to_static.output()[i].getType());
856 TF_RET_CHECK(Shape::Equal()(dim_shape, ShapeUtil::MakeScalarShape(S32)));
857
858 const int64 dim_index = i - 1;
859 llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
860 b_.getInt8Ty(), raw_buffer, raw_data_size + dim_index * sizeof(int32));
861 llvm::Value* dyn_dim_size = b_.CreateLoad(
862 b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()),
863 "dyn_dim_size");
864 dynamic_dims.push_back(dyn_dim_size);
865 }
866
867 // only one thread needs to store the dynamic dimension sizes
868 // int thread_id = GetThreadId();
869 // int block_id = GetBlockId();
870 // if (thread_id == 0 && block_id == 0) {
871 // *output[1] = *dyn_dim0_size;
872 // *output[2] = *dyn_dim1_size;
873 // }
874 KernelSupportLibrary{&b_}.If("is_thread_0", IsBlock0Thread0(&b_), [&] {
875 for (int64 i = 1; i < pad_to_static.output().size(); ++i) {
876 const int64 dim_index = i - 1;
877 llvm::Value* dest_dim_size_address =
878 output_dim_arrays[dim_index].GetBasePointer();
879 // output[i] stores dynamic_dim_(i-1)
880 b_.CreateStore(dynamic_dims[i - 1],
881 b_.CreateBitCast(dest_dim_size_address,
882 b_.getInt32Ty()->getPointerTo()));
883 }
884 });
885
886 // int dyn_element_total = 1;
887 // dyn_element_total *= *dyn_dim0_size;
888 // dyn_element_total *= *dyn_dim1_size;
889 llvm::Value* dyn_element_total = llvm::ConstantInt::get(b_.getInt32Ty(), 1);
890 for (llvm::Value* dynamic_dim : dynamic_dims) {
891 dyn_element_total = b_.CreateMul(dyn_element_total, dynamic_dim,
892 /*Name=*/"dyn_element_total");
893 }
894
895 // linear_index = block_id * threads_per_block + thread_id;
896 // if (linear_index < max_num_element) {
897 // Index static_index =
898 // delinearized(linear_index, static_dim0_size, static_dim1_size);
899 // if (linear_index < dyn_element_total) {
900 // Index dyn_index =
901 // delinearized(linear_index, *dyn_dim0_size, *dyn_dim1_size);
902 // dest_array[dyn_index.dim0][dyn_index.dim1] =
903 // source_array[static_index.dim0][static_index.dim1];
904 // }
905 // }
906 llvm_ir::LoopEmitter::BodyEmitter body_generator =
907 [&](const llvm_ir::IrArray::Index& array_index) -> Status {
908 llvm::Value* linearIndex =
909 array_index.Linearize(input_shape.dimensions(), &b_);
910 auto if_in_dyn_bounds = llvm_ir::EmitIfThenElse(
911 b_.CreateICmpULT(linearIndex, dyn_element_total),
912 llvm_ir::IrName(ir_name, "in_dyn_bounds"), &b_, false);
913 // Set IR builder insertion point to the body of the if structure.
914 llvm_ir::SetToFirstInsertPoint(if_in_dyn_bounds.true_block, &b_);
915 llvm_ir::IrArray::Index dyn_index(linearIndex, input_shape,
916 absl::MakeSpan(dynamic_dims), &b_);
917 output_array.EmitWriteArrayElement(
918 dyn_index,
919 source_array.EmitReadArrayElement(array_index, &b_, /*name=*/""), &b_,
920 /*use_linear_index=*/false);
921 return Status::OK();
922 };
923
924 LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
925 input_shape, ir_emitter_context_->gpu_device_info(), unroll_factor);
926 UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
927 ir_emitter_context_->llvm_module());
928 TF_RETURN_IF_ERROR(
929 ParallelLoopEmitter(body_generator, data_shape, launch_dimensions, &b_,
930 unroll_factor)
931 .EmitLoop(ir_name,
932 GetIndexTypeForKernelFromMlir(
933 pad_to_static, launch_dimensions.launch_bound(), &b_)));
934 thunk_sequence_.emplace_back(std::move(kernel_thunk));
935 return Status::OK();
936 }
937
938 // Input = {static array, dynamic_dim0, dynamic_dim1}
939 // Output = {dynamic array(with dynamic dimension meta data at the end)}
940 Status IrEmitterUnnested::EmitSliceToDynamicFromMlir(
941 MlirEmitterInput mlir_input) {
942 // TODO(jurahul): Create an op to represent SliceToDynamic.
943 auto slice_to_dynamic = mlir::cast<mlir::lmhlo::CustomCallOp>(mlir_input.op);
944 int unroll_factor = 1;
945 std::string ir_name = mlir::GetNameFromLoc(slice_to_dynamic.getLoc());
946
947 std::vector<llvm_ir::IrArray> ir_arrays;
948 TF_ASSIGN_OR_RETURN(
949 auto kernel_thunk,
950 BuildKernelThunkForMlir(slice_to_dynamic, mlir_input.thunk_info,
951 mlir_input.extra_slice, &ir_arrays));
952
953 const Shape& input_shape =
954 TypeToShape(slice_to_dynamic.args().front().getType());
955 TF_RET_CHECK(slice_to_dynamic.output().size() == 1);
956 const Shape& data_shape =
957 TypeToShape(slice_to_dynamic.output().front().getType());
958
959 // TODO(jurahul): data_shape here is the static shape of the output (which has
960 // a dynamic shape in XLA). Currently, we are mapping that to a static shaped
961 // memref. When we change that to a more appropriate representation in MLIR,
962 // fix this code to correctly deduce the static shape backing the dynamically
963 // shaped memref.
964
965 // calculate the location where metadata needs to be inserted
966 // int* dyn_dim0_size = dest_array + meta_data_offset;
967 // int* dyn_dim1_size = dest_array + meta_data_offset + sizeof(int);
968 int32 raw_data_size = ShapeUtil::ByteSizeOf(data_shape);
969
970 // pseudo code for sliceToDynamic on a 2d array
971 // int* source_array = input[0];
972 // int* dest_array = output[0];
973 const llvm_ir::IrArray data_array = ir_arrays.back();
974 llvm::Value* dest_buffer = data_array.GetBasePointer();
975 llvm::Value* raw_buffer =
976 b_.CreateBitCast(dest_buffer, b_.getInt8Ty()->getPointerTo());
977
978 // Load dynamic dimensions from memory.
979 std::vector<llvm::Value*> dynamic_dims;
980 for (int64 i = 1; i < slice_to_dynamic.args().size(); ++i) {
981 // const int64 dim_index = i - 1;
982 llvm::Value* source_buffer = ir_arrays[i].GetBasePointer();
983 llvm::LoadInst* dyn_dim_size = b_.CreateLoad(source_buffer, "dyn_dim_size");
984 dynamic_dims.push_back(dyn_dim_size);
985 }
986
987 // only one thread needs to store the dynamic dimension sizes
988 // int thread_id = GetThreadId();
989 // int block_id = GetBlockId();
990 // if (thread_id == 0 && block_id == 0) {
991 // *dyn_dim0_size = *input[1];
992 // *dyn_dim1_size = *input[2];
993 // }
994 KernelSupportLibrary{&b_}.If("is_thread_0", IsBlock0Thread0(&b_), [&] {
995 for (int64 i = 1; i < slice_to_dynamic.args().size(); ++i) {
996 const int64 dim_index = i - 1;
997 llvm::Value* metadata = b_.CreateConstInBoundsGEP1_32(
998 b_.getInt8Ty(), raw_buffer,
999 raw_data_size + dim_index * sizeof(int32));
1000 // input[i] holds dynamic_dim_(i-1)
1001 b_.CreateStore(
1002 dynamic_dims[dim_index],
1003 b_.CreateBitCast(metadata, b_.getInt32Ty()->getPointerTo()));
1004 }
1005 });
1006
1007 // int dyn_element_total = 1;
1008 // dyn_element_total *= dyn_dim0_size;
1009 // dyn_element_total *= dyn_dim1_size;
1010 llvm::Value* dyn_element_total = llvm::ConstantInt::get(b_.getInt32Ty(), 1);
1011 for (llvm::Value* dynamic_dim : dynamic_dims) {
1012 dyn_element_total = b_.CreateMul(dyn_element_total, dynamic_dim,
1013 /*Name=*/"dyn_element_total");
1014 }
1015
1016 // linear_index = block_id * threads_per_block + thread_id;
1017 // if (linear_index < max_num_element) {
1018 // Index static_index =
1019 // delinearized(linear_index, static_dim0_size, static_dim1_size);
1020 // if (linear_index < dyn_element_total) {
1021 // Index dyn_index =
1022 // delinearized(linear_index, *dyn_dim0_size, *dyn_dim1_size);
1023 // dest_array[static_index.dim0][static_index.dim1] =
1024 // source_array[dyn_index.dim0][dyn_index.dim1];
1025 // }
1026 // }
1027 llvm_ir::LoopEmitter::BodyEmitter body_generator =
1028 [&](const llvm_ir::IrArray::Index& array_index) -> Status {
1029 llvm::Value* linearIndex =
1030 array_index.Linearize(input_shape.dimensions(), &b_);
1031 auto if_in_dyn_bounds = llvm_ir::EmitIfThenElse(
1032 b_.CreateICmpULT(linearIndex, dyn_element_total),
1033 llvm_ir::IrName(ir_name, "in_dyn_bounds"), &b_, false);
1034 // Set IR builder insertion point to the body of the if structure.
1035 llvm_ir::SetToFirstInsertPoint(if_in_dyn_bounds.true_block, &b_);
1036 llvm_ir::IrArray::Index dyn_index(linearIndex, input_shape,
1037 absl::MakeSpan(dynamic_dims), &b_);
1038
1039 data_array.EmitWriteArrayElement(
1040 array_index,
1041 ir_arrays[0].EmitReadArrayElement(dyn_index, &b_, /*name=*/"",
1042 /*use_linear_index=*/false),
1043 &b_);
1044 return Status::OK();
1045 };
1046
1047 LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
1048 input_shape, ir_emitter_context_->gpu_device_info(), unroll_factor);
1049 UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
1050 ir_emitter_context_->llvm_module());
1051
1052 TF_RETURN_IF_ERROR(
1053 ParallelLoopEmitter(body_generator, data_shape, launch_dimensions, &b_,
1054 unroll_factor)
1055 .EmitLoop(ir_name, GetIndexTypeForKernelFromMlir(
1056 slice_to_dynamic,
1057 launch_dimensions.launch_bound(), &b_)));
1058 thunk_sequence_.emplace_back(std::move(kernel_thunk));
1059 return Status::OK();
1060 }
1061
1062 Status IrEmitterUnnested::HandleCustomCall(HloInstruction* custom_call) {
1063 TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(custom_call));
1064 return EmitCustomCallFromMlir(input);
1065 }
1066
1067 Status IrEmitterUnnested::EmitCustomCallFromMlir(MlirEmitterInput input) {
1068 using mlir::dyn_cast;
1069 using mlir::isa;
1070
1071 if (auto call = dyn_cast<mlir::lmhlo::CustomCallOp>(input.op)) {
1072 if (call.call_target_name() == "PadToStatic") {
1073 return EmitPadToStaticFromMlir(input);
1074 }
1075 if (call.call_target_name() == "SliceToDynamic") {
1076 return EmitSliceToDynamicFromMlir(input);
1077 }
1078 return EmitCustomCallThunkFromMlir(input);
1079 }
1080
1081 if (isa<mlir::lmhlo_gpu::GEMMOp, mlir::lmhlo_gpu::GEMM_BiasOp>(input.op)) {
1082 return EmitGemmThunkFromMlir(input);
1083 }
1084
1085 if (mlir::isa<mlir::lmhlo_gpu::ConvForwardOp,
1086 mlir::lmhlo_gpu::ConvForwardFusedOp,
1087 mlir::lmhlo_gpu::ConvForwardFusedSideInputOp,
1088 mlir::lmhlo_gpu::ConvBackwardFilterOp,
1089 mlir::lmhlo_gpu::ConvBackwardInputOp>(input.op)) {
1090 return EmitConvolutionThunkFromMlir(input);
1091 }
1092
1093 if (isa<mlir::lmhlo_gpu::BatchNormTrainingOp,
1094 mlir::lmhlo_gpu::BatchNormInferenceOp,
1095 mlir::lmhlo_gpu::BatchNormGradOp>(input.op)) {
1096 return EmitBatchNormThunkFromMlir(input);
1097 }
1098
1099 #if GOOGLE_CUDA
1100 if (mlir::isa<mlir::lmhlo_gpu::CholeskyOp>(input.op)) {
1101 return EmitCholeskyThunkFromMlir(input);
1102 }
1103 #endif // GOOGLE_CUDA
1104
1105 return Unimplemented("No registered implementation for custom call to \"%s\"",
1106 MlirToString(input.op));
1107 }
1108
1109 Status IrEmitterUnnested::EmitConvolutionThunkFromMlir(MlirEmitterInput input) {
1110 using mlir::dyn_cast;
1111 using mlir::lmhlo_gpu::Activation;
1112 using mlir::lmhlo_gpu::ConvBackwardFilterOp;
1113 using mlir::lmhlo_gpu::ConvBackwardInputOp;
1114 using mlir::lmhlo_gpu::ConvForwardFusedOp;
1115 using mlir::lmhlo_gpu::ConvForwardFusedSideInputOp;
1116 using mlir::lmhlo_gpu::ConvForwardOp;
1117
1118 // Last 2 operands of the convolution operation are the result and scratch.
1119 std::vector<BufferAllocation::Slice> operand_slices;
1120 int64 num_operands = input.op->getNumOperands();
1121 operand_slices.reserve(num_operands - 2);
1122 for (mlir::Value operand : input.op->getOperands().drop_back(2)) {
1123 TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(operand));
1124 operand_slices.push_back(slice);
1125 }
1126
1127 mlir::Value conv_result = input.op->getOperand(num_operands - 2);
1128 mlir::Value scratch_result = input.op->getOperand(num_operands - 1);
1129 TF_ASSIGN_OR_RETURN(auto conv_result_slice,
1130 GetAllocationSliceForMlir(conv_result));
1131 TF_ASSIGN_OR_RETURN(auto scratch_slice,
1132 GetAllocationSliceForMlir(scratch_result));
1133
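// Rebuilds a shape with the minor-to-major order given by the layout
// attribute from the op's backend config.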
1134 auto apply_layout = [](const Shape& shape, mlir::ArrayAttr layout_attrib) {
1135 mlir::SmallVector<int64, 4> minor_to_major = llvm::to_vector<4>(
1136 llvm::map_range(layout_attrib, [](mlir::Attribute a) -> int64 {
1137 return static_cast<int64>(a.cast<mlir::IntegerAttr>().getInt());
1138 }));
1139 return ShapeUtil::MakeShapeWithLayout(shape.element_type(),
1140 shape.dimensions(), minor_to_major);
1141 };
1142
1143 GpuConvDescriptor descriptor;
1144
1145 auto fill_conv_descriptor = [&](auto op) {
1146 descriptor.operand0_shape =
1147 apply_layout(TypeToShape(input.op->getOperand(0).getType()),
1148 op.backend_config().operand_0_layout());
1149 descriptor.operand1_shape =
1150 apply_layout(TypeToShape(input.op->getOperand(1).getType()),
1151 op.backend_config().operand_1_layout());
1152 descriptor.result_shape = apply_layout(TypeToShape(conv_result.getType()),
1153 op.backend_config().result_layout());
1154 descriptor.dnums = ConvertConvDimensionNumbers(op.dimension_numbers());
1155 descriptor.scratch_size =
1156 input.extra_slice->shape.tuple_shapes(1).dimensions(0);
1157 mlir::DenseIntElementsAttr window_strides = op.window_strides().getValue();
1158 mlir::DenseIntElementsAttr padding = op.padding().getValue();
1159 mlir::DenseIntElementsAttr lhs_dilation = op.lhs_dilation().getValue();
1160 mlir::DenseIntElementsAttr rhs_dilation = op.rhs_dilation().getValue();
1161 mlir::DenseElementsAttr window_reversal = op.window_reversal().getValue();
1162 for (auto index : llvm::seq<int>(0, window_strides.getNumElements())) {
1163 WindowDimension* dim = descriptor.window.add_dimensions();
1164 // The window size for a convolution is the same as the kernel size.
1165 // The kernel size of the convolution is operand1_shape. We need to look at
1166 // the convolution dimension numbers' kernel spatial dimensions to get
1167 // the window size.
1168 int kernel_dim = descriptor.dnums.kernel_spatial_dimensions(index);
1169 dim->set_size(descriptor.operand0_shape.dimensions(kernel_dim));
1170 dim->set_stride(window_strides.getValue<int64>(index));
1171 dim->set_padding_low(padding.getValue<int64>(index));
1172 dim->set_padding_high(padding.getValue<int64>(index));
1173 dim->set_base_dilation(lhs_dilation.getValue<int64>(index));
1174 dim->set_window_dilation(rhs_dilation.getValue<int64>(index));
1175 dim->set_window_reversal(window_reversal.getValue<bool>(index));
1176 }
1177 descriptor.feature_group_count = op.feature_group_count();
1178 descriptor.backend_config.set_algorithm(
1179 op.backend_config().algorithm().getInt());
1180 descriptor.backend_config.set_tensor_ops_enabled(
1181 op.backend_config().tensor_ops_enabled().getValue());
1182 descriptor.backend_config.set_conv_result_scale(
1183 op.result_scale().convertToDouble());
1184 };
1185
1186 auto set_activation_mode = [&](auto op) -> Status {
1187 TF_ASSIGN_OR_RETURN(stream_executor::dnn::ActivationMode activation_mode,
1188 ConvertConvActivationMode(op.activation_mode()));
1189 descriptor.backend_config.set_activation_mode(
1190 static_cast<int64>(activation_mode));
1191 return Status::OK();
1192 };
1193
1194 if (auto op = dyn_cast<ConvForwardOp>(input.op)) {
1195 descriptor.kind = CudnnConvKind::kForward;
1196 fill_conv_descriptor(op);
1197 } else if (auto op = dyn_cast<ConvBackwardInputOp>(input.op)) {
1198 descriptor.kind = CudnnConvKind::kBackwardInput;
1199 fill_conv_descriptor(op);
1200 } else if (auto op = dyn_cast<ConvBackwardFilterOp>(input.op)) {
1201 descriptor.kind = CudnnConvKind::kBackwardFilter;
1202 fill_conv_descriptor(op);
1203 } else if (auto op = dyn_cast<ConvForwardFusedOp>(input.op)) {
1204 descriptor.kind = CudnnConvKind::kForwardActivation;
1205 fill_conv_descriptor(op);
1206 TF_RETURN_IF_ERROR(set_activation_mode(op));
1207 } else if (auto op = dyn_cast<ConvForwardFusedSideInputOp>(input.op)) {
1208 descriptor.kind = CudnnConvKind::kForwardActivation;
1209 fill_conv_descriptor(op);
1210 TF_RETURN_IF_ERROR(set_activation_mode(op));
1211 descriptor.backend_config.set_side_input_scale(
1212 op.side_input_scale().convertToDouble());
1213 } else {
1214 return InternalError("Unexpected operation");
1215 }
1216 TF_ASSIGN_OR_RETURN(GpuConvConfig config, GetGpuConvConfig(descriptor, ""));
1217 AddThunkToThunkSequence(absl::make_unique<ConvolutionThunk>(
1218 input.thunk_info, std::move(config), std::move(operand_slices),
1219 conv_result_slice, scratch_slice));
1220 return Status::OK();
1221 }
1222
1223 Status IrEmitterUnnested::EmitGemmThunkFromMlir(MlirEmitterInput input) {
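// Populates a GpuGemmConfig (operand/output shapes, alpha, algorithm, batch
// size, and dot dimension numbers) from a GEMM or GEMM_Bias LMHLO op.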
1224 auto build_gemm_config = [](auto op) {
1225 GpuGemmConfig config;
1226 GemmBackendConfig& backend = config.backend_config;
1227 config.output_shape = TypeToShape(op.output().getType());
1228 config.lhs_shape = TypeToShape(op.lhs().getType());
1229 config.rhs_shape = TypeToShape(op.rhs().getType());
1230 backend.Clear();
1231 if (op.algorithm()) {
1232 backend.set_selected_algorithm(*op.algorithm());
1233 }
1234 backend.set_alpha_real(op.alpha_real().convertToDouble());
1235 backend.set_alpha_imag(op.alpha_imag().convertToDouble());
1236 backend.set_batch_size(op.batch_size());
1237
1238 auto& dims = *backend.mutable_dot_dimension_numbers();
1239 auto mlir_dims = op.dot_dimension_numbers();
1240
1241 auto fill_dims = [](mlir::DenseElementsAttr mlir_dim, auto* config_attrs) {
1242 for (llvm::APInt e : mlir_dim.getIntValues())
1243 config_attrs->Add(e.getSExtValue());
1244 };
1245 fill_dims(mlir_dims.lhs_batching_dimensions(),
1246 dims.mutable_lhs_batch_dimensions());
1247 fill_dims(mlir_dims.rhs_batching_dimensions(),
1248 dims.mutable_rhs_batch_dimensions());
1249 fill_dims(mlir_dims.lhs_contracting_dimensions(),
1250 dims.mutable_lhs_contracting_dimensions());
1251 fill_dims(mlir_dims.rhs_contracting_dimensions(),
1252 dims.mutable_rhs_contracting_dimensions());
1253 return config;
1254 };
1255
1256 GpuGemmConfig config;
1257 BufferAllocation::Slice lhs, rhs, bias, output;
1258
1259 if (auto gemm = mlir::dyn_cast<mlir::lmhlo_gpu::GEMMOp>(input.op)) {
1260 config = build_gemm_config(gemm);
1261 TF_ASSIGN_OR_RETURN(lhs, GetAllocationSliceForMlir(gemm.lhs()));
1262 TF_ASSIGN_OR_RETURN(rhs, GetAllocationSliceForMlir(gemm.rhs()));
1263 TF_ASSIGN_OR_RETURN(output, GetAllocationSliceForMlir(gemm.output()));
1264 } else if (auto gemm_bias =
1265 mlir::dyn_cast<mlir::lmhlo_gpu::GEMM_BiasOp>(input.op)) {
1266 config = build_gemm_config(gemm_bias);
1267 config.backend_config.set_beta(gemm_bias.beta().convertToDouble());
1268 TF_ASSIGN_OR_RETURN(lhs, GetAllocationSliceForMlir(gemm_bias.lhs()));
1269 TF_ASSIGN_OR_RETURN(rhs, GetAllocationSliceForMlir(gemm_bias.rhs()));
1270 TF_ASSIGN_OR_RETURN(bias, GetAllocationSliceForMlir(gemm_bias.bias()));
1271 TF_ASSIGN_OR_RETURN(output, GetAllocationSliceForMlir(gemm_bias.output()));
1272
1273 // The bias is passed inside the output buffer. If those buffers are shared,
1274 // we can just use it; otherwise, copy the bias values into the output
1275 // buffer first.
1276 if (bias != output) {
1277 std::vector<std::unique_ptr<Thunk>> thunks;
1278
1279 thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
1280 Thunk::ThunkInfo(),
1281 /*source_buffer=*/bias,
1282 /*destination_buffer=*/output,
1283 /*mem_size=*/ShapeUtil::ByteSizeOf(config.output_shape)));
1284 thunks.push_back(absl::make_unique<GemmThunk>(
1285 input.thunk_info, std::move(config), lhs, rhs, output,
1286 /*implements_whole_instruction=*/false));
1287 AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
1288 input.thunk_info, std::move(thunks)));
1289 return Status::OK();
1290 }
1291 }
1292
1293 AddThunkToThunkSequence(absl::make_unique<GemmThunk>(
1294 input.thunk_info, std::move(config), lhs, rhs, output,
1295 /*implements_whole_instruction=*/true));
1296 return Status::OK();
1297 }
1298
1299 namespace {
1300 // An MLIR value and its name as defined in the ODS spec.
1301 struct NamedValue {
1302 mlir::Value value;
1303 absl::string_view name;
1304 };
1305
1306 // Verifies that the given batch norm is well formed for thunk emission. This
1307 // requires that all statistics operands (mean, stddev, etc.) have F32 type and
1308 // that all non-statistics operands match in shape, element type, and layout
1309 // (which maps to them having the same memref type).
1310 Status VerifyBatchNormForThunkEmission(
1311 mlir::ArrayRef<NamedValue> statistics_operands,
1312 mlir::ArrayRef<NamedValue> other_operands) {
1313 for (const NamedValue& v : statistics_operands) {
1314 // Note: MLIR verification will ensure that the operands of the batchnorm
1315 // LHLO are valid memref types.
1316 if (!v.value.getType().cast<mlir::MemRefType>().getElementType().isF32()) {
1317 return Unimplemented("Operand %s of batch norm should have F32 type",
1318 v.name);
1319 }
1320 }
1321 if (other_operands.empty()) {
1322 return Status::OK();
1323 }
1324
1325 mlir::Type first_type = other_operands.front().value.getType();
1326 absl::string_view first_name = other_operands.front().name;
1327
1328 for (const NamedValue& v : other_operands.drop_front(1)) {
1329 if (v.value.getType() != first_type) {
1330 return Unimplemented("%s and %s for batch norm should have same types",
1331 v.name, first_name);
1332 }
1333 }
1334
1335 return Status::OK();
1336 }
1337 } // namespace
1338
1339 Status IrEmitterUnnested::EmitBatchNormThunkFromMlir(MlirEmitterInput input) {
1340 auto get_batch_norm_config = [](auto op, mlir::Value output) {
1341 CudnnBatchNormConfig config;
1342 config.output_shape = TypeToShape(output.getType());
1343 config.output_type = config.output_shape.element_type();
1344 config.epsilon = op.epsilon().convertToFloat();
1345 config.feature_index = op.feature_index();
1346 return config;
1347 };
1348
1349 // The statistics operands for batch norm operations must be F32. The rest
1350 // of the operands must match each other in shape, layout, and element
1351 // type.
1352 if (auto bn_train =
1353 mlir::dyn_cast<mlir::lmhlo_gpu::BatchNormTrainingOp>(input.op)) {
1354 TF_RETURN_IF_ERROR(VerifyBatchNormForThunkEmission(
1355 /*statistics_operands=*/
1356 {{bn_train.scale(), "scale"},
1357 {bn_train.offset(), "offset"},
1358 {bn_train.batch_mean(), "batch_mean"},
1359 {bn_train.batch_stddev(), "batch_stddev"}},
1360 /*other_operands=*/
1361 {{bn_train.operand(), "operand"}, {bn_train.output(), "output"}}));
1362 TF_ASSIGN_OR_RETURN(auto operand,
1363 GetAllocationSliceForMlir(bn_train.operand()));
1364 TF_ASSIGN_OR_RETURN(auto scale,
1365 GetAllocationSliceForMlir(bn_train.scale()));
1366 TF_ASSIGN_OR_RETURN(auto offset,
1367 GetAllocationSliceForMlir(bn_train.offset()));
1368
1369 // BatchNormTraining returns a tuple of three elements: data, calculated
1370 // mean, and calculated 1/sqrt(variance + epsilon).
1371 TF_ASSIGN_OR_RETURN(auto output_data,
1372 GetAllocationSliceForMlir(bn_train.output()));
1373 TF_ASSIGN_OR_RETURN(auto output_mean,
1374 GetAllocationSliceForMlir(bn_train.batch_mean()));
1375 TF_ASSIGN_OR_RETURN(auto output_inv_stddev,
1376 GetAllocationSliceForMlir(bn_train.batch_stddev()));
1377
1378 AddThunkToThunkSequence(
1379 absl::make_unique<CudnnBatchNormForwardTrainingThunk>(
1380 input.thunk_info,
1381 /*config=*/get_batch_norm_config(bn_train, bn_train.output()),
1382 /*operand=*/operand,
1383 /*scale=*/scale,
1384 /*offset=*/offset,
1385 /*output_data=*/output_data,
1386 /*output_mean=*/output_mean,
1387 /*output_inv_stddev=*/output_inv_stddev));
1388 return Status::OK();
1389 }
1390
1391 if (auto bn_grad =
1392 mlir::dyn_cast<mlir::lmhlo_gpu::BatchNormGradOp>(input.op)) {
1393 TF_RETURN_IF_ERROR(VerifyBatchNormForThunkEmission(
1394 /*statistics_operands=*/
1395 {{bn_grad.scale(), "scale"},
1396 {bn_grad.mean(), "mean"},
1397 {bn_grad.stddev(), "stddev"},
1398 {bn_grad.grad_scale(), "grad_scale"},
1399 {bn_grad.grad_offset(), "grad_offset"}},
1400 /*other_operands=*/
1401 {{bn_grad.operand(), "operand"},
1402 {bn_grad.grad_output(), "grad_output"},
1403 {bn_grad.grad_operand(), "grad_operand"}}));
1404
1405 TF_ASSIGN_OR_RETURN(auto operand,
1406 GetAllocationSliceForMlir(bn_grad.operand()));
1407 TF_ASSIGN_OR_RETURN(auto scale, GetAllocationSliceForMlir(bn_grad.scale()));
1408 TF_ASSIGN_OR_RETURN(auto mean, GetAllocationSliceForMlir(bn_grad.mean()));
1409 TF_ASSIGN_OR_RETURN(auto inv_stddev,
1410 GetAllocationSliceForMlir(bn_grad.stddev()));
1411 TF_ASSIGN_OR_RETURN(auto grad_output,
1412 GetAllocationSliceForMlir(bn_grad.grad_output()));
1413
1414 // BatchNormGrad returns a tuple of three elements: grad_data, grad_scale,
1415 // grad_offset.
1416 TF_ASSIGN_OR_RETURN(auto output_grad_data,
1417 GetAllocationSliceForMlir(bn_grad.grad_operand()));
1418 TF_ASSIGN_OR_RETURN(auto output_grad_scale,
1419 GetAllocationSliceForMlir(bn_grad.grad_scale()));
1420 TF_ASSIGN_OR_RETURN(auto output_grad_offset,
1421 GetAllocationSliceForMlir(bn_grad.grad_offset()));
1422
1423 CudnnBatchNormConfig config;
1424 config.output_shape = TypeToShape(bn_grad.grad_output().getType());
1425 config.output_type = config.output_shape.element_type();
1426 config.epsilon = bn_grad.epsilon().convertToFloat();
1427 config.feature_index = bn_grad.feature_index();
1428
1429 AddThunkToThunkSequence(absl::make_unique<CudnnBatchNormBackwardThunk>(
1430 input.thunk_info,
1431 /*config=*/get_batch_norm_config(bn_grad, bn_grad.grad_output()),
1432 /*operand=*/operand,
1433 /*scale=*/scale,
1434 /*mean=*/mean,
1435 /*inv_stddev=*/inv_stddev,
1436 /*grad_output=*/grad_output,
1437 /*output_grad_data=*/output_grad_data,
1438 /*output_grad_scale=*/output_grad_scale,
1439 /*output_grad_offset=*/output_grad_offset));
1440 return Status::OK();
1441 }
1442
1443 if (auto bn_inference =
1444 mlir::dyn_cast<mlir::lmhlo_gpu::BatchNormInferenceOp>(input.op)) {
1445 TF_RETURN_IF_ERROR(
1446 VerifyBatchNormForThunkEmission(/*statistics_operands=*/
1447 {{bn_inference.scale(), "scale"},
1448 {bn_inference.offset(), "offset"},
1449 {bn_inference.mean(), "mean"},
1450 {bn_inference.stddev(), "stddev"}},
1451 /*other_operands=*/
1452 {{bn_inference.operand(), "operand"},
1453 {bn_inference.output(), "output"}}));
1454
1455 TF_ASSIGN_OR_RETURN(auto operand,
1456 GetAllocationSliceForMlir(bn_inference.operand()));
1457 TF_ASSIGN_OR_RETURN(auto scale,
1458 GetAllocationSliceForMlir(bn_inference.scale()));
1459 TF_ASSIGN_OR_RETURN(auto offset,
1460 GetAllocationSliceForMlir(bn_inference.offset()));
1461 TF_ASSIGN_OR_RETURN(auto mean,
1462 GetAllocationSliceForMlir(bn_inference.mean()));
1463 TF_ASSIGN_OR_RETURN(auto variance,
1464 GetAllocationSliceForMlir(bn_inference.stddev()));
1465 TF_ASSIGN_OR_RETURN(auto output,
1466 GetAllocationSliceForMlir(bn_inference.output()));
1467
1468 AddThunkToThunkSequence(absl::make_unique<
1469 CudnnBatchNormForwardInferenceThunk>(
1470 input.thunk_info,
1471 /*config=*/get_batch_norm_config(bn_inference, bn_inference.output()),
1472 /*operand=*/operand,
1473 /*scale=*/scale,
1474 /*offset=*/offset,
1475 /*mean=*/mean,
1476 /*variance=*/variance,
1477 /*output=*/output));
1478 return Status::OK();
1479 }
1480
1481 return Unimplemented("Unsupported batch norm operation");
1482 }
1483
1484 #if GOOGLE_CUDA
1485 Status IrEmitterUnnested::EmitCholeskyThunkFromMlir(MlirEmitterInput input) {
1486 auto cholesky_op = mlir::cast<mlir::lmhlo_gpu::CholeskyOp>(input.op);
1487
1488 const Shape shape = TypeToShape(cholesky_op.input().getType());
1489 int ndim = shape.dimensions_size();
1490 CHECK_GE(ndim, 2);
1491 int64 n = shape.dimensions(ndim - 1);
1492
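// All dimensions except the trailing two (the matrix dimensions) are batch
// dimensions. For example, an f32[6, 4, 8, 8] input yields batch_size = 24 and
// n = 8.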
1493 const auto& dims = shape.dimensions();
1494 int64 batch_size = std::accumulate(dims.begin(), dims.end() - 2, int64{1},
1495 [](int64 a, int64 b) { return a * b; });
1496
1497 TF_ASSIGN_OR_RETURN(auto operand_buffer,
1498 GetAllocationSliceForMlir(cholesky_op.input()));
1499 TF_ASSIGN_OR_RETURN(auto a_buffer,
1500 GetAllocationSliceForMlir(cholesky_op.output()));
1501 TF_ASSIGN_OR_RETURN(auto workspace_buffer,
1502 GetAllocationSliceForMlir(cholesky_op.scratch()));
1503 TF_ASSIGN_OR_RETURN(auto info_buffer,
1504 GetAllocationSliceForMlir(cholesky_op.info()));
1505
1506 std::vector<std::unique_ptr<Thunk>> thunks;
1507
1508 if (operand_buffer != a_buffer) {
1509 thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
1510 input.thunk_info,
1511 /*source_address=*/operand_buffer,
1512 /*destination_buffer=*/a_buffer,
1513 /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
1514 }
1515
1516 CholeskyOptions options;
1517 options.set_lower(cholesky_op.is_lower());
1518 thunks.push_back(absl::make_unique<CholeskyThunk>(
1519 input.thunk_info, options, a_buffer, workspace_buffer, info_buffer,
1520 shape.element_type(), batch_size, n));
1521
1522 // Elide the sequential thunk if there's no copy.
1523 if (thunks.size() == 1) {
1524 AddThunkToThunkSequence(std::move(thunks[0]));
1525 } else {
1526 AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
1527 input.thunk_info, std::move(thunks)));
1528 }
1529
1530 return Status::OK();
1531 }
1532 #endif // GOOGLE_CUDA
1533
1534 Status IrEmitterUnnested::EmitCustomCallThunkFromMlir(MlirEmitterInput input) {
1535 auto custom_call = mlir::cast<mlir::lmhlo::CustomCallOp>(input.op);
1536 const std::string call_target_name = custom_call.call_target_name().str();
1537
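// Resolve the call target against the registry for this platform. Custom calls
// whose target has not been registered are rejected below.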
1538 void* call_target = CustomCallTargetRegistry::Global()->Lookup(
1539 call_target_name, std::string(platform_name()));
1540 if (call_target) {
1541 std::vector<BufferAllocation::Slice> operands;
1542 for (mlir::Value arg : custom_call.args()) {
1543 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice arg_slice,
1544 GetAllocationSliceForMlir(arg));
1545 operands.push_back(arg_slice);
1546 }
1547
1548 std::vector<BufferAllocation::Slice> results;
1549 for (mlir::Value output : custom_call.output()) {
1550 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1551 GetAllocationSliceForMlir(output));
1552 results.push_back(output_slice);
1553 }
1554
1555 AddThunkToThunkSequence(absl::make_unique<CustomCallThunk>(
1556 input.thunk_info, call_target, std::move(operands), std::move(results),
1557 custom_call.backend_config().str()));
1558 return Status::OK();
1559 }
1560 return Unimplemented("No registered implementation for custom call to \"%s\"",
1561 call_target_name);
1562 }
1563
1564 Status IrEmitterUnnested::HandleFft(HloInstruction* fft) {
1565 TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(fft));
1566 return EmitFftThunkFromMlir(input);
1567 }
1568
1569 Status IrEmitterUnnested::EmitFftThunkFromMlir(MlirEmitterInput input) {
1570 auto fft_op = mlir::cast<mlir::lmhlo::FftOp>(input.op);
1571 const Shape operand_shape = TypeToShape(fft_op.operand().getType());
1572 const Shape output_shape = TypeToShape(fft_op.output().getType());
1573 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(operand_shape.layout()));
1574 TF_RET_CHECK(LayoutUtil::IsMonotonicWithDim0Major(output_shape.layout()));
1575
1576 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice arg_slice,
1577 GetAllocationSliceForMlir(fft_op.operand()));
1578 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice dest_slice,
1579 GetAllocationSliceForMlir(fft_op.output()));
1580 TF_ASSIGN_OR_RETURN(xla::FftType fft_type, ConvertFftType(fft_op.fft_type()));
1581 auto fft_length_values = fft_op.fft_length().getValues<int64>();
1582 std::vector<int64> fft_length(fft_length_values.begin(),
1583 fft_length_values.end());
1584 AddThunkToThunkSequence(
1585 absl::make_unique<FftThunk>(input.thunk_info, fft_type, fft_length,
1586 /*input_buffer=*/arg_slice,
1587 /*output_buffer=*/dest_slice,
1588 /*input_shape=*/operand_shape,
1589 /*output_shape=*/output_shape));
1590 return Status::OK();
1591 }
1592
1593 Status IrEmitterUnnested::HandleTriangularSolve(HloInstruction* hlo) {
1594 TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
1595 return EmitTriangularSolveFromMlir(input);
1596 }
1597
1598 Status IrEmitterUnnested::EmitTriangularSolveFromMlir(MlirEmitterInput input) {
1599 auto triangular_solve_op =
1600 mlir::cast<mlir::lmhlo::TriangularSolveOp>(input.op);
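// The emitted solver thunk assumes a Fortran-style layout for the matrix
// operands: the two minor-most dimensions of each layout must be the last two
// logical dimensions, which is what the checks below enforce.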
1601 auto has_fortran_layout = [](mlir::DenseIntElementsAttr layout_attr) {
1602 int64_t n = layout_attr.getNumElements();
1603 return layout_attr.getValue<int64_t>({0}) == n - 2 &&
1604 layout_attr.getValue<int64_t>({1}) == n - 1;
1605 };
1606 TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_a()));
1607 TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_b()));
1608 TF_RET_CHECK(has_fortran_layout(triangular_solve_op.layout_output()));
1609
1610 const Shape b_shape = TypeToShape(triangular_solve_op.b().getType());
1611
1612 const Shape output_shape =
1613 TypeToShape(triangular_solve_op.output().getType());
1614
1615 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice a_slice,
1616 GetAllocationSliceForMlir(triangular_solve_op.a()));
1617 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice b_slice,
1618 GetAllocationSliceForMlir(triangular_solve_op.b()));
1619 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice output_slice,
1620 GetAllocationSliceForMlir(triangular_solve_op.output()));
1621 TF_ASSIGN_OR_RETURN(TriangularSolveOptions_Transpose transpose_a,
1622 ConvertTranspose(triangular_solve_op.transpose_a()));
1623
1624 std::vector<std::unique_ptr<Thunk>> thunks;
1625
1626 // Triangular solve is in-place on 'b', so copy 'b' to the output if they
1627 // aren't the same buffer.
1628 if (b_slice != output_slice) {
1629 thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
1630 Thunk::ThunkInfo(),
1631 /*source_address=*/b_slice,
1632 /*destination_buffer=*/output_slice,
1633 /*mem_size=*/ShapeUtil::ByteSizeOf(b_shape)));
1634 }
1635
1636 int64 m = b_shape.dimensions(b_shape.rank() - 2);
1637 int64 n = b_shape.dimensions(b_shape.rank() - 1);
1638 int64 batch_size = std::accumulate(b_shape.dimensions().begin(),
1639 b_shape.dimensions().end() - 2, int64{1},
1640 [](int64 a, int64 b) { return a * b; });
1641 int64 elem_size =
1642 ShapeUtil::ByteSizeOfPrimitiveType(output_shape.element_type());
1643 int64 a_batch_stride =
1644 triangular_solve_op.left_side() ? m * m * elem_size : n * n * elem_size;
1645 int64 b_batch_stride = m * n * elem_size;
1646 TriangularSolveOptions options;
1647 options.set_left_side(triangular_solve_op.left_side());
1648 options.set_lower(triangular_solve_op.lower());
1649 options.set_unit_diagonal(triangular_solve_op.unit_diagonal());
1650 options.set_transpose_a(transpose_a);
1651 thunks.push_back(absl::make_unique<TriangularSolveThunk>(
1652 input.thunk_info, options,
1653 /*a_input_buffer=*/a_slice,
1654 /*b_input_buffer=*/output_slice, output_shape.element_type(), batch_size,
1655 m, n, a_batch_stride, b_batch_stride));
1656
1657 // Elide the sequential thunk if there's no copy.
1658 if (thunks.size() == 1) {
1659 AddThunkToThunkSequence(std::move(thunks[0]));
1660 } else {
1661 AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
1662 input.thunk_info, std::move(thunks)));
1663 }
1664 return Status::OK();
1665 }
1666
1667 // Convert the following form of fusion region:
1668 // fusion() {
1669 // %0 = tensor_load %external_memref0
1670 // %1 = tensor_load %external_memref1
1671 // ...
1672 // tensor_store %ret, %external_memref2
1673 // }
1674 // to
1675 // fusion(%external_memref0, %external_memref1) (^bb(%0, %1) {
1676 // ...
1677 // mhlo.return %ret
1678 // })
1679 //
1680 // So that it's suitable for MHLO -> XLA HLO conversion.
1681 // This function won't be needed once ElementalIrEmitter migrates to take MHLO
1682 // instead.
1683 static Status ProcessFusionForConversion(mlir::Region* region,
1684 std::vector<Shape>* operand_shapes,
1685 std::vector<Shape>* output_shapes) {
1686 std::vector<mlir::TensorLoadOp> loads;
1687 std::vector<mlir::TensorStoreOp> stores;
1688
1689 region->walk([&](mlir::TensorLoadOp load) {
1690 if (load.memref().getParentRegion() != region) {
1691 loads.push_back(load);
1692 }
1693 });
1694
1695 region->walk([&](mlir::TensorStoreOp store) {
1696 if (store.memref().getParentRegion() != region) {
1697 stores.push_back(store);
1698 }
1699 });
1700
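// Turn each load of an external memref into a region argument, recording the
// operand shape (with an explicit minor-to-major layout when one is attached,
// and a default descending layout otherwise).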
1701 for (auto load : loads) {
1702 auto arg = region->addArgument(load.getType());
1703 load.replaceAllUsesWith(arg);
1704 Shape shape = TypeToShape(load.getType());
1705 if (auto attr = mlir::GetLayoutFromMlirHlo(load)) {
1706 std::vector<int64> minor_to_major;
1707 absl::c_transform(
1708 attr, std::back_inserter(minor_to_major),
1709 std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
1710 *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
1711 } else {
1712 *shape.mutable_layout() =
1713 LayoutUtil::MakeDescendingLayout(load.getType().getShape().size());
1714 }
1715 operand_shapes->push_back(std::move(shape));
1716 load.erase();
1717 }
1718
1719 std::vector<mlir::Value> returned_values;
1720 for (auto store : stores) {
1721 Shape shape = TypeToShape(store.memref().getType());
1722 if (auto attr = mlir::GetLayoutFromMlirHlo(store)) {
1723 std::vector<int64> minor_to_major;
1724 absl::c_transform(
1725 attr, std::back_inserter(minor_to_major),
1726 std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
1727 *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
1728 }
1729 output_shapes->push_back(shape);
1730
1731 returned_values.push_back(store.tensor());
1732 store.erase();
1733 }
1734
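// Drop the old terminator and return the stored tensors via mhlo.return, which
// is the terminator expected after the conversion described above.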
1735 region->back().back().erase();
1736 auto b = mlir::OpBuilder::atBlockEnd(&region->back());
1737 auto loc = returned_values[0].getLoc();
1738 b.create<mlir::mhlo::ReturnOp>(loc, returned_values);
1739 return Status::OK();
1740 }
1741
1742 StatusOr<MlirEmitterInput> IrEmitterUnnested::GetMlirEmitterInput(
1743 HloInstruction* hlo) {
1744 MlirEmitterInput input;
1745 TF_ASSIGN_OR_RETURN(input.op, lhlo_scratch_emitter_->EmitOp(hlo));
1746 input.thunk_info = GetThunkInfo(hlo);
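// For tuple-shaped results, also pass along the top-level tuple buffer (shape
// index {}) as an extra written slice so the kernel can emit the tuple index
// table.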
1747 if (hlo->shape().IsTuple()) {
1748 const auto& buffer_assignment = ir_emitter_context_->buffer_assignment();
1749 auto& slice = input.extra_slice.emplace();
1750 TF_ASSIGN_OR_RETURN(slice.buffer_slice,
1751 buffer_assignment.GetUniqueSlice(hlo, {}));
1752 slice.written = true;
1753 slice.shape = hlo->shape();
1754 }
1755 return input;
1756 }
1757
1758 // TODO(timshen): update the comment once the HandleFusion code path deleted.
1759 //
1760 // This is migrated from IrEmitter::HandleFusion() with IrEmitterUnnested as the
1761 // subclass. The logic is de-virtualized and less scattered.
1762 Status IrEmitterUnnested::EmitLoopFusionFromMlir(
1763 MlirEmitterInput input, const Shape& output_shape,
1764 absl::optional<int> unroll_factor_override) {
1765 auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(input.op);
1766 MlirEmitterContext context;
1767 context.SetOperation(fusion);
1768
1769 std::vector<llvm_ir::IrArray> ir_arrays;
1770 Thunk* kernel_thunk;
1771 {
1772 TF_ASSIGN_OR_RETURN(std::unique_ptr<KernelThunk> kernel_thunk_ptr,
1773 BuildKernelThunkForMlir(fusion, input.thunk_info,
1774 input.extra_slice, &ir_arrays));
1775 kernel_thunk = kernel_thunk_ptr.get();
1776 thunk_sequence_.emplace_back(std::move(kernel_thunk_ptr));
1777 }
1778
1779 auto operand_arrays =
1780 absl::MakeSpan(ir_arrays).subspan(0, context.operand_shapes.size());
1781 auto output_element_arrays = absl::MakeSpan(ir_arrays).subspan(
1782 context.operand_shapes.size(), context.output_shapes.size());
1783 const llvm_ir::IrArray* tuple_output_array = nullptr;
1784 if (ir_arrays.size() ==
1785 context.operand_shapes.size() + context.output_shapes.size() + 1) {
1786 tuple_output_array = &ir_arrays[context.operand_shapes.size() +
1787 context.output_shapes.size()];
1788 }
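// When BuildKernelThunkForMlir appended an array for input.extra_slice, that
// trailing array is the top-level tuple buffer; it is used below to emit the
// tuple index table for multi-output fusions.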
1789
1790 TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation,
1791 GetOrCreateSubComputationFromRegion(&fusion.region(),
1792 /*is_fusion=*/true));
1793
1794 GpuElementalIrEmitter elemental_emitter(hlo_module_config_, module_, &b_,
1795 GetNestedComputer());
1796 FusedIrEmitter fused_emitter(&elemental_emitter);
1797
1798 for (int i = 0; i < context.operand_shapes.size(); i++) {
1799 auto* builder = &b_;
1800 auto ir_array = operand_arrays[i];
1801 fused_emitter.BindGenerator(
1802 fused_computation->parameter_instruction(i),
1803 [builder, ir_array](llvm_ir::IrArray::Index index) {
1804 return ir_array.EmitReadArrayElement(index, builder);
1805 });
1806 }
1807 TF_ASSIGN_OR_RETURN(
1808 auto element_generator,
1809 fused_emitter.GetGenerator(fused_computation->root_instruction()));
1810
1811 int unroll_factor;
1812 if (unroll_factor_override.has_value()) {
1813 unroll_factor = *unroll_factor_override;
1814 } else if (!MayPreventVectorization(fusion)) {
1815 unroll_factor = ComputeMaxUnrollFactor(fusion, hlo_module_config_);
1816 } else {
1817 unroll_factor = 1;
1818 }
1819
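// "few_waves" is true iff the fusion body contains only elementwise ops,
// mhlo.broadcast ops with empty broadcast_sizes, and load/store/terminator
// ops; the flag is forwarded to CalculateLaunchDimensions below.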
1820 bool few_waves = [fusion]() mutable {
1821 for (mlir::Operation& op : fusion.region().front()) {
1822 if (mlir::isa<mlir::TensorLoadOp, mlir::TensorStoreOp,
1823 mlir::lmhlo::TerminatorOp, mlir::mhlo::ReturnOp>(op)) {
1824 continue;
1825 }
1826 HloOpcode opcode = *MhloToHloOpcode(&op);
1827 if (HloInstruction::IsOpElementwise(opcode)) {
1828 continue;
1829 }
1830 if (auto broadcast = mlir::dyn_cast<mlir::mhlo::BroadcastOp>(op)) {
1831 if (broadcast.broadcast_sizes().size() == 0) {
1832 continue;
1833 }
1834 }
1835 return false;
1836 }
1837 return true;
1838 }();
1839
1840 Shape element_shape = context.output_shapes[0];
1841 LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
1842 element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor,
1843 few_waves);
1844 UpdateLaunchDimensions(launch_dimensions, kernel_thunk,
1845 ir_emitter_context_->llvm_module());
1846 llvm::Type* index_type = GetIndexTypeForKernelFromMlir(
1847 fusion, launch_dimensions.launch_bound(), &b_);
1848
1849 if (context.output_shapes.size() > 1) {
1850 // Emit the tuple pointers in one thread. We could do this at any point in
1851 // the kernel, but we do it at the beginning in the hopes of reducing
1852 // register pressure, since we touch threadIdx.x and blockIdx.x at the
1853 // beginning of the kernel *anyway*.
1854 KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] {
1855 llvm_ir::EmitTuple(*tuple_output_array, output_element_arrays, &b_);
1856 });
1857 // For multioutput fusion, we need to emit each operand and the root.
1858 TF_RETURN_IF_ERROR(
1859 ParallelLoopEmitter(element_generator, output_element_arrays,
1860 launch_dimensions, &b_, unroll_factor)
1861 .EmitLoop(context.name, index_type));
1862 } else {
1863 TF_RETURN_IF_ERROR(
1864 ParallelLoopEmitter(element_generator, output_element_arrays[0],
1865 launch_dimensions, &b_, unroll_factor)
1866 .EmitLoop(context.name, index_type));
1867 }
1868
1869 b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
1870 return Status::OK();
1871 }
1872
1873 Status IrEmitterUnnested::HandleFusion(HloInstruction* fusion) {
1874 TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(fusion));
1875 auto fusion_op = mlir::cast<mlir::lmhlo::FusionOp>(mlir_input.op);
1876
1877 if (fusion->IsInputFusion()) {
1878 switch (fusion->fused_expression_root()->opcode()) {
1879 case HloOpcode::kScatter: {
1880 TF_ASSIGN_OR_RETURN(
1881 const HloComputation* fused_computation,
1882 GetOrCreateSubComputationFromRegion(&fusion_op.region(),
1883 /*is_fusion=*/true));
1884 auto* root = fused_computation->root_instruction();
1885
1886 std::vector<std::unique_ptr<Thunk>> thunks;
1887 // The initialization from 'operand' uses different loop bounds, so emit it
1888 // in a separate kernel. Treat it like a loop fusion, writing to the output
1889 // buffer.
1890 {
1891 std::vector<llvm_ir::IrArray> ir_arrays;
1892 TF_ASSIGN_OR_RETURN(
1893 auto operand_thunk,
1894 BuildKernelThunkForMlir(mlir_input.op, Thunk::ThunkInfo(),
1895 mlir_input.extra_slice, &ir_arrays));
1896 thunks.push_back(std::move(operand_thunk));
1897
1898 GpuElementalIrEmitter operand_elemental_emitter(
1899 hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
1900 GetNestedComputer());
1901 FusedIrEmitter operand_fused_emitter(&operand_elemental_emitter);
1902 for (int i = 0; i < fused_computation->num_parameters(); i++) {
1903 auto fused_operand = fused_computation->parameter_instruction(i);
1904 operand_fused_emitter.BindGenerator(
1905 fused_operand, [this, &ir_arrays, i, fused_operand](
1906 const llvm_ir::IrArray::Index& index) {
1907 return ir_arrays[i].EmitReadArrayElement(
1908 index, &b_, fused_operand->name());
1909 });
1910 }
1911 TF_ASSIGN_OR_RETURN(
1912 auto generator,
1913 operand_fused_emitter.GetGenerator(root->operand(0)));
1914
1915 auto unroll_factor =
1916 ComputeMaxUnrollFactor(fusion_op, hlo_module_config_);
1917 const Shape& element_shape = root->shape();
1918 LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
1919 element_shape, ir_emitter_context_->gpu_device_info(),
1920 unroll_factor, /*few_waves=*/false);
1921 UpdateLaunchDimensions(launch_dimensions, thunks.back().get(),
1922 ir_emitter_context_->llvm_module());
1923 TF_RETURN_IF_ERROR(
1924 ParallelLoopEmitter(generator, ir_arrays.back(),
1925 launch_dimensions, &b_, unroll_factor)
1926 .EmitLoop(
1927 IrName(mlir::GetNameFromLoc(fusion_op.getLoc())),
1928 GetIndexTypeForKernelFromMlir(
1929 fusion_op, launch_dimensions.launch_bound(), &b_)));
1930 }
1931
1932 // Now build the actual scatter, reading and writing to the freshly
1933 // filled output buffer.
1934 {
1935 std::vector<llvm_ir::IrArray> ir_arrays;
1936 TF_ASSIGN_OR_RETURN(
1937 auto scatter_thunk,
1938 BuildKernelThunkForMlir(mlir_input.op, Thunk::ThunkInfo(),
1939 mlir_input.extra_slice, &ir_arrays));
1940 thunks.push_back(std::move(scatter_thunk));
1941 // Spin up a new fused emitter for the scatter kernel and emit it.
1942 GpuElementalIrEmitter scatter_elemental_emitter(
1943 hlo_module_config_, ir_emitter_context_->llvm_module(), &b_,
1944 GetNestedComputer());
1945 FusedIrEmitter scatter_fused_emitter(&scatter_elemental_emitter);
1946 for (int i = 0; i < fused_computation->num_parameters(); i++) {
1947 auto fused_operand = fused_computation->parameter_instruction(i);
1948 scatter_fused_emitter.BindGenerator(
1949 fused_operand, [this, &ir_arrays, i, fused_operand](
1950 const llvm_ir::IrArray::Index& index) {
1951 return ir_arrays[i].EmitReadArrayElement(
1952 index, &b_, fused_operand->name());
1953 });
1954 }
1955
1956 TF_ASSIGN_OR_RETURN(
1957 const auto dim_numbers,
1958 lhlo_scratch_emitter_->GetScatterDimensionNumbers(root));
1959
1960 ScatterDescriptor desc;
1961 desc.name = IrName(root);
1962 desc.operand_shape = root->operand(0)->shape();
1963 desc.scatter_indices_shape = root->operand(1)->shape();
1964 desc.updates_shape = root->operand(2)->shape();
1965 desc.dim_numbers = dim_numbers;
1966 desc.unique_indices = root->unique_indices();
1967 desc.update_computation = root->called_computations()[0];
1968 desc.output = ir_arrays.back();
1969 TF_ASSIGN_OR_RETURN(
1970 desc.scatter_indices_gen,
1971 scatter_fused_emitter.GetGenerator(root->operand(1)));
1972 TF_ASSIGN_OR_RETURN(
1973 desc.updates_gen,
1974 scatter_fused_emitter.GetGenerator(root->operand(2)));
1975 desc.get_index_type = [&](int64 launch_size) {
1976 return GetIndexTypeForKernel(root, launch_size, &b_);
1977 };
1978
1979 TF_RETURN_IF_ERROR(EmitScatter(desc, thunks.back().get()));
1980 }
1981 AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
1982 mlir_input.thunk_info, std::move(thunks)));
1983 return Status::OK();
1984 }
1985 // In the case of root tuple, it can be either reduce or slice input
1986 // fusion.
1987 case HloOpcode::kTuple: {
1988 if (IsInputFusibleSlices(mlir_input.op, /*verify_no_strides=*/false)) {
1989 return EmitInputFusibleNonStridedSlices(mlir_input);
1990 }
1991
1992 CHECK_GE(mlir::cast<mlir::lmhlo::FusionOp>(mlir_input.op)
1993 .getFusionResults()
1994 .size(),
1995 1);
1996 return EmitReductionFromOrToContiguousDimensions(mlir_input);
1997 }
1998 case HloOpcode::kReduce: {
1999 // HandleFusion specializes reduction from a multi-dimensional array to
2000 // a 1D array. The specialized version requires an initializer thunk that
2001 // initializes the output array to the initial value of the reduce.
2002 if (mlir_input.op->getNumResults() > 1) {
2003 // TODO(b/129089333): Support tiled vectorized variadic reduce.
2004 return Unimplemented(
2005 "Vectorized variadic reduce is not supported on GPU");
2006 }
2007 return EmitReductionFromOrToContiguousDimensions(mlir_input);
2008 }
2009 case HloOpcode::kSlice: {
2010 return EmitInputFusibleNonStridedSlices(mlir_input);
2011 }
2012 default:
2013 LOG(FATAL) << "Bad opcode for input fusion: "
2014 << fusion->fused_expression_root()->opcode();
2015 }
2016 } else if (CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
2017 fusion_op, ir_emitter_context_->allocations())) {
2018 // Fusion node with dynamic-update-slice as the root where the op's input
2019 // (i.e. array to update) shares the same slice as its output. In this case
2020 // we have a special algorithm that modifies the output in place without
2021 // touching the un-updated elements.
2022 CHECK_EQ(1, GetHloOutputs(mlir_input.op).size());
2023
2024 // Set up kernel thunk and fused ir emitter.
2025 std::vector<llvm_ir::IrArray> ir_arrays;
2026 TF_ASSIGN_OR_RETURN(
2027 auto fusion_thunk,
2028 BuildKernelThunkForMlir(fusion_op, mlir_input.thunk_info,
2029 mlir_input.extra_slice, &ir_arrays));
2030
2031 GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
2032 ir_emitter_context_->llvm_module(),
2033 &b_, GetNestedComputer());
2034
2035 // Shape of the dynamic-update-slice's "update" operand.
2036 Shape update_shape = fusion->fused_expression_root()->operand(1)->shape();
2037
2038 // Array to write into. Because this is an in-place operation, this is the
2039 // same as operand 0's array.
2040 const IrArray& output_array = ir_arrays.back();
2041
2042 LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
2043 update_shape, ir_emitter_context_->gpu_device_info());
2044 UpdateLaunchDimensions(launch_dimensions, fusion_thunk.get(),
2045 ir_emitter_context_->llvm_module());
2046 AddThunkToThunkSequence(std::move(fusion_thunk));
2047
2048 FusedIrEmitter fused_emitter(&elemental_emitter);
2049
2050 TF_ASSIGN_OR_RETURN(
2051 const HloComputation* fused_computation,
2052 GetOrCreateSubComputationFromRegion(&fusion_op.region(),
2053 /*is_fusion=*/true));
2054
2055 for (int i = 0; i < fused_computation->num_parameters(); i++) {
2056 auto fused_operand = fused_computation->parameter_instruction(i);
2057 fused_emitter.BindGenerator(
2058 fused_operand, [this, &ir_arrays, i,
2059 fused_operand](const llvm_ir::IrArray::Index& index) {
2060 return ir_arrays[i].EmitReadArrayElement(index, &b_,
2061 fused_operand->name());
2062 });
2063 }
2064
2065 return llvm_ir::EmitParallelFusedDynamicUpdateSliceInPlace(
2066 fused_computation, output_array, &fused_emitter, launch_dimensions,
2067 &b_);
2068 }
2069
2070 CHECK_EQ(fusion->fusion_kind(), HloInstruction::FusionKind::kLoop)
2071 << ": " << fusion->ToString();
2072
2073 TF_ASSIGN_OR_RETURN(const bool matched_021,
2074 CheckAndEmitHloWithTile021(mlir_input));
2075 if (matched_021) {
2076 return Status::OK();
2077 }
2078
2079 return EmitLoopFusionFromMlir(mlir_input, fusion->shape());
2080 }
2081
2082 Status IrEmitterUnnested::HandleCopy(HloInstruction* copy) {
2083 TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(copy));
2084 return EmitCopyForMlir(input);
2085 }
2086
2087 Status IrEmitterUnnested::EmitCopyForMlir(MlirEmitterInput input) {
2088 auto copy = mlir::cast<mlir::lmhlo::CopyOp>(input.op);
2089 auto operand_shape = TypeToShape(copy.operand().getType());
2090 auto output_shape = TypeToShape(copy.output().getType());
2091
2092 CHECK(ShapeUtil::Compatible(operand_shape, output_shape));
2093 auto maybe_slice = GetAllocationSliceForMlir(copy.operand());
2094 if (LayoutUtil::Equal(operand_shape.layout(), output_shape.layout()) &&
2095 maybe_slice.ok()) {
2096 // Copy the operand into the output if it's not the same buffer already.
2097 auto operand_buffer = *maybe_slice;
2098 auto destination_buffer = *GetAllocationSliceForMlir(copy.output());
2099 if (operand_buffer != destination_buffer) {
2100 AddThunkToThunkSequence(absl::make_unique<DeviceToDeviceCopyThunk>(
2101 input.thunk_info,
2102 /*source_address=*/operand_buffer,
2103 /*destination_buffer=*/destination_buffer,
2104 /*mem_size=*/
2105 ByteSizeOf(operand_shape)));
2106 }
2107 return Status::OK();
2108 }
2109 TF_ASSIGN_OR_RETURN(bool matched_021, CheckAndEmitHloWithTile021(input));
2110 if (matched_021) {
2111 return Status::OK();
2112 }
2113
2114 return EmitUsingElementalIrEmitter(input);
2115 }
2116
2117 Status IrEmitterUnnested::EmitExtraOutputsForReduce(
2118 absl::Span<const llvm_ir::IrArray> result_ir_arrays,
2119 const IrArray::Index& index, bool use_linear_index,
2120 absl::Span<const std::pair<llvm_ir::ElementGenerator, int>>
2121 extra_output_gens) {
2122 // Compute all extra output values before writing them. This avoids
2123 // overwriting aliased input/output buffers before all reads have occurred.
2124 absl::InlinedVector<llvm::Value*, 8> extra_output_ir_values;
2125 for (int i = 0; i < extra_output_gens.size(); ++i) {
2126 TF_ASSIGN_OR_RETURN(llvm::Value* const extra_output_ir_value,
2127 extra_output_gens[i].first(index));
2128 extra_output_ir_values.push_back(extra_output_ir_value);
2129 }
2130 for (int i = 0; i < extra_output_gens.size(); ++i) {
2131 result_ir_arrays[extra_output_gens[i].second].EmitWriteArrayElement(
2132 index, extra_output_ir_values[i], &b_, use_linear_index);
2133 }
2134 return Status::OK();
2135 }
2136
2137 Status IrEmitterUnnested::HandleReduce(HloInstruction* reduce) {
2138 TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(reduce));
2139 return EmitReduceFromMlir(mlir_input);
2140 }
2141
2142 Status IrEmitterUnnested::EmitReduceFromMlir(MlirEmitterInput mlir_input) {
2143 if (GetHloOutputs(mlir_input.op).size() == 1 &&
2144 IsReductionFromOrToContiguousDimensions(mlir_input.op)) {
2145 return EmitReductionFromOrToContiguousDimensions(mlir_input);
2146 }
2147
2148 return EmitUsingElementalIrEmitter(mlir_input);
2149 }
2150
2151 Status IrEmitterUnnested::HandleTuple(HloInstruction* tuple) {
2152 // For all tuples, we expect the elements of the tuple to be consumed directly
2153 // by instructions that read from that tuple, either directly or through a GTE
2154 // instruction. This is possible because we do not support "dynamic tuples"
2155 // (tuple-select is not supported). As a result, we never need to materialize
2156 // a tuple (which has a runtime representation of an array of pointers) in
2157 // memory at runtime, so there is no need to generate any code for tuples.
2158 return Status::OK();
2159 }
2160
2161 Status IrEmitterUnnested::HandleGetTupleElement(HloInstruction*) {
2162 // GetTupleElement IR is emitted in the IR context of the user instruction,
2163 // and so we do not build a kernel for GetTupleElement instructions.
2164 return Status::OK();
2165 }
2166
2167 Status IrEmitterUnnested::AssertNonDeterminismIsOkay(const string& op_name) {
2168 if (hlo_module_config_.debug_options().xla_gpu_deterministic_ops()) {
2169 return Unimplemented(
2170 "HLO instruction %s does not have a deterministic implementation, "
2171 "but run-to-run determinism is required by "
2172 "--xla_gpu_deterministic_ops.",
2173 op_name);
2174 }
2175 return Status::OK();
2176 }
2177
2178 Status IrEmitterUnnested::HandleSelectAndScatter(
2179 HloInstruction* select_and_scatter) {
2180 const Window& window = select_and_scatter->window();
2181 const auto* operand = select_and_scatter->operand(0);
2182 const auto* source = select_and_scatter->operand(1);
2183 const int64 rank = operand->shape().rank();
2184 CHECK_EQ(rank, source->shape().rank());
2185 CHECK_EQ(rank, window.dimensions_size());
2186
2187 // TODO(b/31410564): Implement dilation rate for select-and-scatter.
2188 if (window_util::HasDilation(window)) {
2189 return Unimplemented(
2190 "Dilation for SelectAndScatter not implemented on GPU.");
2191 }
2192
2193 TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay(select_and_scatter->name()));
2194
2195 TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(select_and_scatter));
2196 return EmitSelectAndScatterFromMlir(input);
2197 }
2198
2199 Status IrEmitterUnnested::EmitSelectAndScatterFromMlir(
2200 MlirEmitterInput mlir_input) {
2201 auto select_and_scatter_op =
2202 mlir::cast<mlir::lmhlo::SelectAndScatterOp>(mlir_input.op);
2203
2204 std::string name = mlir::GetNameFromLoc(select_and_scatter_op.getLoc());
2205
2206 std::vector<std::unique_ptr<Thunk>> thunks;
2207 thunks.emplace_back();
2208 TF_ASSIGN_OR_RETURN(thunks.back(),
2209 BuildInitializerThunkForMlir(
2210 mlir_input.op, select_and_scatter_op.init_value(),
2211 select_and_scatter_op.out()));
2212
2213 std::vector<llvm_ir::IrArray> ir_arrays;
2214 thunks.emplace_back();
2215 // Init value is not needed in IR emission.
2216 TF_ASSIGN_OR_RETURN(
2217 thunks.back(),
2218 BuildKernelThunkForMlir(
2219 select_and_scatter_op,
2220 {select_and_scatter_op.operand(), select_and_scatter_op.source(),
2221 select_and_scatter_op.out()},
2222 Thunk::ThunkInfo(), mlir_input.extra_slice, &ir_arrays));
2223
2224 CHECK_EQ(ir_arrays.size(), 3);
2225 const IrArray& operand_array = ir_arrays[0];
2226 const IrArray& source_array = ir_arrays[1];
2227 const IrArray& out_array = ir_arrays[2];
2228
2229 auto select_and_scatter_thunk = absl::make_unique<SequentialThunk>(
2230 mlir_input.thunk_info, std::move(thunks));
2231
2232 const Shape source_shape =
2233 TypeToShape(select_and_scatter_op.source().getType());
2234 const Shape operand_shape =
2235 TypeToShape(select_and_scatter_op.operand().getType());
2236 const int64 rank = operand_shape.rank();
2237
2238 LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
2239 source_shape, ir_emitter_context_->gpu_device_info());
2240 llvm::Type* index_type = GetIndexTypeForKernelFromMlir(
2241 select_and_scatter_op, launch_dimensions.launch_bound(), &b_);
2242 auto index_typed_constant = [&](uint64 c) -> llvm::Constant* {
2243 return llvm::ConstantInt::get(index_type, c);
2244 };
2245
2246 // kSelectAndScatter is implemented as two kernel launches: the first launch
2247 // initializes the output array to the given initial value,
2248 // and the second accumulates the "source" matrix to the
2249 // selected elements in the output array. The first launch is already
2250 // implemented by the initializer thunk generated earlier, so this function
2251 // only needs to take care of the select-and-scatter part.
2252 //
2253 // Pseudo code for select-and-scatter:
2254 //
2255 // for (coordinates S in the source): # This loop is parallel.
2256 // initialized_flag = false
2257 // for (coordinates W in the window):
2258 // I = S * stride + W - pad_low
2259 // if I within bounds of operand:
2260 // if !(initialized_flag and select(selected_value, operand(I))):
2261 // selected_value = operand(I)
2262 // selected_index = I
2263 // initialized_flag = true
2264 // output(selected_index) = scatter(output(selected_index), source(S))
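// The selected value, its index, and the initialized_flag below live in
// allocas so they can be carried across iterations of the window loop.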
2265 auto loop_body_emitter = [&](const IrArray::Index& source_index) -> Status {
2266 // Allocate space to keep the currently selected value, its index, and a
2267 // boolean flag if the value is initialized. The initialized_flag is set
2268 // false.
2269 llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
2270 llvm_ir::PrimitiveTypeToIrType(operand_shape.element_type(),
2271 ir_emitter_context_->llvm_module()),
2272 "selected_value_address", &b_);
2273
2274 llvm::Value* selected_index_address =
2275 llvm_ir::EmitAllocaAtFunctionEntryWithCount(
2276 index_type, index_typed_constant(rank), "selected_index_address",
2277 &b_);
2278
2279 llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
2280 b_.getInt1Ty(), "initialized_flag_address", &b_);
2281 Store(b_.getInt1(false), initialized_flag_address);
2282
2283 // Create the inner loop to iterate over the window.
2284 llvm_ir::ForLoopNest window_loops(absl::StrCat(name, "inner"), &b_,
2285 index_type);
2286
2287 DimensionVector window_size;
2288 mlir::DenseIntElementsAttr window_dimensions =
2289 select_and_scatter_op.window_dimensions().getValue();
2290 for (const auto& dim : window_dimensions) {
2291 window_size.push_back(dim.getSExtValue());
2292 CHECK_GT(dim.getSExtValue(), 0);
2293 }
2294
2295 const IrArray::Index window_index = window_loops.AddLoopsForShape(
2296 ShapeUtil::MakeShape(operand_shape.element_type(), window_size),
2297 "window");
2298 llvm_ir::SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
2299 &b_);
2300
2301 // Compute the operand index to visit and evaluate the condition whether the
2302 // operand index is within the bounds. The unsigned comparison includes
2303 // checking whether the operand index >= 0.
2304 std::vector<llvm::Value*> operand_multi_index(source_index.size());
2305 llvm::Value* in_bounds_condition = b_.getInt1(true);
2306
2307 auto strides = *select_and_scatter_op.window_strides();
2308 auto paddings = *select_and_scatter_op.padding();
2309
2310 for (auto stride_and_padding :
2311 llvm::enumerate(llvm::zip(strides, paddings))) {
2312 const int i = stride_and_padding.index();
2313 int64 stride = std::get<0>(stride_and_padding.value()).getSExtValue();
2314 int64 padding = std::get<1>(stride_and_padding.value()).getSExtValue();
2315
2316 llvm::Value* strided_index =
2317 NSWMul(source_index[i], index_typed_constant(stride));
2318 operand_multi_index[i] = NSWSub(NSWAdd(strided_index, window_index[i]),
2319 index_typed_constant(padding));
2320 llvm::Value* index_condition = ICmpULT(
2321 operand_multi_index[i],
2322 index_typed_constant(ShapeUtil::GetDimension(operand_shape, i)));
2323 in_bounds_condition = And(in_bounds_condition, index_condition);
2324 }
2325
2326 // Only need to do something if the operand index is within the bounds.
2327 // First check if the initialized_flag is set.
2328 llvm_ir::LlvmIfData if_in_bounds =
2329 llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &b_);
2330 llvm_ir::SetToFirstInsertPoint(if_in_bounds.true_block, &b_);
2331 llvm_ir::LlvmIfData if_initialized = llvm_ir::EmitIfThenElse(
2332 Load(initialized_flag_address), "initialized", &b_);
2333
2334 // If the initialized_flag is false, initialize the selected value and index
2335 // with the currently visiting operand.
2336 llvm_ir::SetToFirstInsertPoint(if_initialized.false_block, &b_);
2337 const auto save_operand_index = [&](const IrArray::Index& operand_index) {
2338 for (int64 i = 0; i < rank; ++i) {
2339 llvm::Value* selected_index_address_slot =
2340 InBoundsGEP(selected_index_address, {b_.getInt32(i)});
2341 Store(operand_index[i], selected_index_address_slot);
2342 }
2343 };
2344 IrArray::Index operand_index(operand_multi_index, operand_shape,
2345 index_type);
2346 llvm::Value* operand_data =
2347 operand_array.EmitReadArrayElement(operand_index, &b_);
2348 Store(operand_data, selected_value_address);
2349 save_operand_index(operand_index);
2350 Store(b_.getInt1(true), initialized_flag_address);
2351
2352 // If the initialized_flag is true, call the `select` function to
2353 // potentially update the selected value and index with the currently
2354 // visiting operand.
2355 llvm_ir::SetToFirstInsertPoint(if_initialized.true_block, &b_);
2356 llvm::Value* operand_address =
2357 operand_array.EmitArrayElementAddress(operand_index, &b_);
2358 llvm::Value* select_return_buffer = llvm_ir::EmitAllocaAtFunctionEntry(
2359 llvm_ir::PrimitiveTypeToIrType(PRED,
2360 ir_emitter_context_->llvm_module()),
2361 "select_return_buffer", &b_);
2362
2363 TF_ASSIGN_OR_RETURN(
2364 const HloComputation* select_computation,
2365 GetOrCreateSubComputationFromRegion(&select_and_scatter_op.select(),
2366 /*is_fusion=*/false));
2367
2368 TF_RETURN_IF_ERROR(EmitCallToNestedComputation(
2369 *select_computation, {selected_value_address, operand_address},
2370 select_return_buffer));
2371 llvm::Value* result = Load(select_return_buffer);
2372
2373 // If the 'select' function returns false, update the selected value and the
2374 // index to the currently visiting operand.
2375 llvm::Value* cond = ICmpNE(
2376 result,
2377 llvm::ConstantInt::get(llvm_ir::PrimitiveTypeToIrType(
2378 PRED, ir_emitter_context_->llvm_module()),
2379 0),
2380 "boolean_predicate");
2381 llvm_ir::LlvmIfData if_select_lhs =
2382 llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &b_);
2383 llvm_ir::SetToFirstInsertPoint(if_select_lhs.false_block, &b_);
2384 Store(Load(operand_address), selected_value_address);
2385 save_operand_index(operand_index);
2386
2387 // After iterating over the window elements, scatter the source element to
2388 // the selected index of the output. The value we store at the output
2389 // location is computed by calling the `scatter` function with the source
2390 // value and the current output value.
2391 llvm_ir::SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
2392 &b_);
2393 std::vector<llvm::Value*> selected_multi_index;
2394 for (int64 i = 0; i < rank; ++i) {
2395 llvm::Value* selected_index_address_slot =
2396 InBoundsGEP(selected_index_address, {b_.getInt32(i)});
2397 selected_multi_index.push_back(Load(selected_index_address_slot));
2398 }
2399 const Shape output_shape =
2400 TypeToShape(select_and_scatter_op.out().getType());
2401 llvm::Value* source_value_address =
2402 source_array.EmitArrayElementAddress(source_index, &b_);
2403 IrArray::Index selected_index(selected_multi_index, output_shape,
2404 operand_index.GetType());
2405 llvm::Value* output_value_address =
2406 out_array.EmitArrayElementAddress(selected_index, &b_);
2407
2408 TF_ASSIGN_OR_RETURN(
2409 const HloComputation* scatter_computation,
2410 GetOrCreateSubComputationFromRegion(&select_and_scatter_op.scatter(),
2411 /*is_fusion=*/false));
2412
2413 return EmitAtomicOperationForNestedComputation(
2414 *scatter_computation, output_value_address, source_value_address);
2415 };
2416
2417 UpdateLaunchDimensions(
2418 launch_dimensions,
2419 // IrEmitterUnnested implements kSelectAndScatter as a SequentialThunk
2420 // consisting of two thunks, an initializer KernelThunk that initializes
2421 // the output and another KernelThunk that accumulates the scattered
2422 // elements.
2423 select_and_scatter_thunk->thunks().back().get(),
2424 ir_emitter_context_->llvm_module());
2425 AddThunkToThunkSequence(std::move(select_and_scatter_thunk));
2426 return ParallelLoopEmitter(loop_body_emitter, source_shape, launch_dimensions,
2427 &b_)
2428 .EmitLoop(name, index_type);
2429 }
2430
2431 Status IrEmitterUnnested::HandleWhile(HloInstruction* xla_while) {
2432 HloComputation* condition = xla_while->while_condition();
2433 TF_RET_CHECK(ShapeUtil::IsScalar(condition->root_instruction()->shape()) &&
2434 condition->root_instruction()->shape().element_type() == PRED)
2435 << "While condition computation must return bool";
2436 // Build ForThunk for conformant while loops, otherwise build WhileThunk.
2437 auto config = xla_while->backend_config<WhileLoopBackendConfig>();
2438 if (config.ok() && config.ValueOrDie().has_known_trip_count()) {
2439 TF_ASSIGN_OR_RETURN(
2440 auto thunk,
2441 BuildForThunk(xla_while, config.ValueOrDie().known_trip_count().n()));
2442 AddThunkToThunkSequence(std::move(thunk));
2443 } else {
2444 TF_ASSIGN_OR_RETURN(auto thunk, BuildWhileThunk(xla_while));
2445 AddThunkToThunkSequence(std::move(thunk));
2446 }
2447 return Status::OK();
2448 }
2449
2450 Status IrEmitterUnnested::HandleRng(HloInstruction* rng) {
2451 return Unimplemented("Rng should be expanded for GPU.");
2452 }
2453
2454 Status IrEmitterUnnested::HandleRngGetAndUpdateState(
2455 HloInstruction* rng_state) {
2456 TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(rng_state));
2457 return EmitRngGetAndUpdateState(input);
2458 }
2459
2460 Status IrEmitterUnnested::EmitRngGetAndUpdateState(
2461 MlirEmitterInput mlir_input) {
2462 auto rng_op =
2463 mlir::dyn_cast<mlir::lmhlo::RngGetAndUpdateStateOp>(mlir_input.op);
2464
2465 // Emit a kernel to increment the global state for Philox RNG algorithm.
2466 std::vector<llvm_ir::IrArray> ir_arrays;
2467 TF_ASSIGN_OR_RETURN(
2468 auto kernel_thunk,
2469 BuildKernelThunkForMlir(rng_op, rng_op.state(), mlir_input.thunk_info,
2470 mlir_input.extra_slice, &ir_arrays));
2471 AddThunkToThunkSequence(std::move(kernel_thunk));
2472
2473 llvm::Value* old_state =
2474 llvm_ir::RngGetAndUpdateState(rng_op.delta(), module_, &b_);
2475
2476 const Shape shape = TypeToShape(rng_op.state().getType());
2477
2478 llvm::Value* output_address = ir_arrays[0].EmitArrayElementAddress(
2479 llvm_ir::IrArray::Index(
2480 /*linear=*/b_.getInt64(0), shape, &b_),
2481 &b_, "rng_state_address");
2482 output_address = BitCast(
2483 output_address, llvm::PointerType::get(
2484 old_state->getType(),
2485 output_address->getType()->getPointerAddressSpace()));
2486 Store(old_state, output_address);
2487
2488 return Status::OK();
2489 }
2490
2491 Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
2492 if (!scatter->unique_indices()) {
2493 TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay(scatter->name()));
2494 }
2495 TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(scatter));
2496 return EmitScatterFromMlir(input);
2497 }
2498
2499 Status IrEmitterUnnested::EmitScatterFromMlir(MlirEmitterInput mlir_input) {
2500 std::vector<std::unique_ptr<Thunk>> thunks;
2501
2502 auto scatter_op = mlir::cast<mlir::lmhlo::ScatterOp>(mlir_input.op);
2503
2504 TF_ASSIGN_OR_RETURN(auto operand_buffer,
2505 GetAllocationSliceForMlir(scatter_op.operand()));
2506 TF_ASSIGN_OR_RETURN(auto output_buffer,
2507 GetAllocationSliceForMlir(scatter_op.output()));
2508
2509 // Copy the operand into the output if it's not the same buffer already.
2510 if (operand_buffer != output_buffer) {
2511 thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
2512 Thunk::ThunkInfo(),
2513 /*source_address=*/operand_buffer,
2514 /*destination_buffer=*/output_buffer,
2515 /*mem_size=*/
2516 ShapeUtil::ByteSizeOf(TypeToShape(scatter_op.output().getType()))));
2517 }
2518
2519 // Create kernel thunk for all operands except the first one (`operand`). The
2520 // code generated for scatter below assumes that the input operand is already
2521 // copied into the output, so does not use it in codegen.
2522 std::vector<llvm_ir::IrArray> ir_arrays;
2523 thunks.emplace_back();
2524 TF_ASSIGN_OR_RETURN(
2525 thunks.back(),
2526 BuildKernelThunkForMlir(scatter_op, scatter_op.getOperands().drop_front(),
2527 mlir_input.thunk_info, mlir_input.extra_slice,
2528 &ir_arrays));
2529
2530 CHECK_EQ(ir_arrays.size(), 3);
2531 const IrArray& scatter_indices = ir_arrays[0];
2532 const IrArray& updates = ir_arrays[1];
2533 const IrArray& output = ir_arrays[2];
2534
2535 auto get_index_type = [&](int64 launch_size) {
2536 return GetIndexTypeForKernelFromMlir(scatter_op, launch_size, &b_);
2537 };
2538
2539 TF_RETURN_IF_ERROR(EmitScatter(
2540 thunks.back().get(), scatter_op, output,
2541 /*scatter_indices_gen=*/
2542 [&](const IrArray::Index& index) {
2543 return scatter_indices.EmitReadArrayElement(index, &b_,
2544 "scatter_index");
2545 },
2546 /*updates_gen=*/
2547 [&](const IrArray::Index& index) {
2548 return updates.EmitReadArrayElement(index, &b_, "update");
2549 },
2550 /* get_index_type=*/
2551 get_index_type));
2552
2553 // Elide the sequential thunk if there's no copy.
2554 if (thunks.size() == 1) {
2555 AddThunkToThunkSequence(std::move(thunks[0]));
2556 } else {
2557 AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
2558 mlir_input.thunk_info, std::move(thunks)));
2559 }
2560
2561 return Status::OK();
2562 }
2563
2564 Status IrEmitterUnnested::EmitScatter(
2565 Thunk* thunk, mlir::lmhlo::ScatterOp scatter,
2566 const llvm_ir::IrArray& output,
2567 const llvm_ir::ElementGenerator& scatter_indices_gen,
2568 const llvm_ir::ElementGenerator& updates_gen,
2569 std::function<llvm::Type*(int64)> get_index_type) {
2570 const Shape operand_shape = TypeToShape(scatter.operand().getType());
2571 CHECK(
2572 ShapeUtil::Equal(TypeToShape(scatter.output().getType()), operand_shape));
2573
2574 TF_ASSIGN_OR_RETURN(
2575 const HloComputation* update_computation,
2576 GetOrCreateSubComputationFromRegion(&scatter.update_computation(),
2577 /*is_fusion=*/false));
2578
2579 ScatterDescriptor desc;
2580 desc.name = mlir::GetNameFromLoc(scatter.getLoc());
2581 desc.operand_shape = operand_shape;
2582 desc.scatter_indices_shape = TypeToShape(scatter.scatter_indices().getType());
2583 desc.updates_shape = TypeToShape(scatter.updates().getType());
2584 desc.dim_numbers = scatter.scatter_dimension_numbers();
2585 desc.unique_indices = scatter.unique_indices();
2586 desc.update_computation = update_computation;
2587 desc.output = output;
2588 desc.scatter_indices_gen = scatter_indices_gen;
2589 desc.updates_gen = updates_gen;
2590 desc.get_index_type = get_index_type;
2591 return EmitScatter(desc, thunk);
2592 }
2593
2594 Status IrEmitterUnnested::EmitScatter(const ScatterDescriptor& desc,
2595 Thunk* thunk) {
2596 if (!desc.unique_indices) {
2597 TF_RETURN_IF_ERROR(AssertNonDeterminismIsOkay(desc.name));
2598 }
2599 auto loop_body_emitter = [&](const IrArray::Index& index) -> Status {
2600 std::vector<llvm::Value*> raw_window_multidim;
2601 std::vector<llvm::Value*> input_scatter_multidim;
2602 std::vector<int64> raw_window_bounds;
2603
2604 // Partition the index into window indices and scatter indices.
2605 for (int64 i = 0, e = index.size(); i != e; ++i) {
2606 // For window indices, also remember the window size; this comes in handy
2607 // later.
2608 if (BinarySearchDenseElementsAttr(desc.dim_numbers.update_window_dims(),
2609 i)) {
2610 raw_window_multidim.push_back(index[i]);
2611 raw_window_bounds.push_back(desc.updates_shape.dimensions(i));
2612 } else {
2613 input_scatter_multidim.push_back(index[i]);
2614 }
2615 }
2616 DCHECK_EQ(raw_window_multidim.size(),
2617 desc.dim_numbers.update_window_dims().size());
2618
2619 // Apply inserted_window_dims to the window dimensions.
2620 int64 raw_window_multidim_idx = 0;
2621 std::vector<llvm::Value*> input_window_multidim;
2622 std::vector<int64> input_window_bounds;
2623
2624 for (int64 i = 0, e = desc.operand_shape.rank(); i != e; ++i) {
2625 if (BinarySearchDenseElementsAttr(desc.dim_numbers.inserted_window_dims(),
2626 i)) {
2627 input_window_bounds.push_back(1); // Trivial dimension.
2628 input_window_multidim.push_back(index.GetConstantWithIndexType(0));
2629 } else {
2630 input_window_bounds.push_back(
2631 raw_window_bounds[raw_window_multidim_idx]);
2632 input_window_multidim.push_back(
2633 raw_window_multidim[raw_window_multidim_idx]);
2634 ++raw_window_multidim_idx;
2635 }
2636 }
2637 DCHECK_EQ(input_window_multidim.size(), desc.operand_shape.rank());
2638
2639 // Insert a 1 dimension at the end if index_vector_dim requests one.
2640 Shape scatter_indices_shape_fixed = desc.scatter_indices_shape;
2641 if (desc.dim_numbers.index_vector_dim().getInt() ==
2642 desc.scatter_indices_shape.rank()) {
2643 scatter_indices_shape_fixed.add_dimensions(1);
2644 scatter_indices_shape_fixed.mutable_layout()->add_minor_to_major(
2645 desc.dim_numbers.index_vector_dim().getInt());
2646 }
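// For example, s32[7, 5] scatter indices with index_vector_dim == 2 are
// treated as s32[7, 5, 1], giving each index vector an explicit dimension.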
2647
2648 // Now load the indices corresponding to the current window from
2649 // scatter_indices.
2650 std::vector<llvm::Value*> raw_scatter_index_multidim =
2651 input_scatter_multidim;
2652 raw_scatter_index_multidim.insert(
2653 raw_scatter_index_multidim.begin() +
2654 desc.dim_numbers.index_vector_dim().getInt(),
2655 nullptr);
2656 llvm::Value* is_in_bounds = b_.getTrue();
2657 for (int64 i = 0,
2658 e = desc.dim_numbers.scatter_dims_to_operand_dims().size();
2659 i != e; ++i) {
2660 // Our index is stored along index_vector_dim; insert that into the lookup
2661 // index into scatter_indices.
2662 raw_scatter_index_multidim[desc.dim_numbers.index_vector_dim().getInt()] =
2663 index.GetConstantWithIndexType(i);
2664 llvm_ir::IrArray::Index raw_scatter_index_index(
2665 raw_scatter_index_multidim, scatter_indices_shape_fixed,
2666 index.GetType());
2667
2668 int64 operand_dim =
2669 desc.dim_numbers.scatter_dims_to_operand_dims().getValue<int64>(i);
2670 TF_ASSIGN_OR_RETURN(
2671 llvm::Value* const loaded_scatter_index,
2672 desc.scatter_indices_gen(raw_scatter_index_index.SourceIndexOfReshape(
2673 scatter_indices_shape_fixed, desc.scatter_indices_shape, &b_)));
2674 // And add the index to our window index. This yields the output index.
2675 llvm::Value* casted_scatter_index =
2676 IntCast(loaded_scatter_index, index.GetType(),
2677 /*isSigned=*/true);
2678 llvm::Value* dim_offset =
2679 Add(input_window_multidim[operand_dim], casted_scatter_index);
2680 input_window_multidim[operand_dim] = dim_offset;
2681
2682 // Also do the bounds check now.
2683 int64 max_index = desc.operand_shape.dimensions(operand_dim) -
2684 input_window_bounds[operand_dim] + 1;
2685 // is_in_bounds = index >= 0 && index < dim_size-window_size+1
2686 // --> index u< dim_size-window_size+1
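      // Illustrative example (assumed sizes): with dim_size = 10 and
      // window_size = 3, the valid window start indices are 0..7, so
      // max_index = 8. A negative scatter index wraps around to a huge
      // unsigned value, so the single unsigned comparison below also catches
      // the index < 0 case.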
2687 is_in_bounds =
2688 And(is_in_bounds, ICmpULT(casted_scatter_index,
2689 index.GetConstantWithIndexType(max_index)));
2690 }
2691
2692 llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse(
2693 is_in_bounds, "scatter.in_bounds", &b_, /*emit_else=*/false);
2694 llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_);
2695     // All done; now read the update value for the current index and store it
2696     // to the calculated window location in the output (atomically unless the
2697     // indices are known to be unique).
2697 llvm_ir::IrArray::Index input_window_index(
2698 input_window_multidim, desc.output.GetShape(), index.GetType());
2699 llvm::Value* output_address =
2700 desc.output.EmitArrayElementAddress(input_window_index, &b_);
2701 llvm::Value* input_address = llvm_ir::EmitAllocaAtFunctionEntry(
2702 llvm_ir::PrimitiveTypeToIrType(desc.updates_shape.element_type(),
2703 module_),
2704 "input_address", &b_);
2705 TF_ASSIGN_OR_RETURN(llvm::Value* const input_ir_value,
2706 desc.updates_gen(index));
2707 Store(input_ir_value, input_address);
2708
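    // If the scatter indices are not known to be unique, several threads may
    // update the same output element, so the update computation has to be
    // applied atomically; otherwise a plain call is sufficient.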
2709 if (!desc.unique_indices) {
2710 return EmitAtomicOperationForNestedComputation(
2711 *desc.update_computation, output_address, input_address);
2712 } else {
2713 return EmitCallToNestedComputation(*desc.update_computation,
2714 {output_address, input_address},
2715 output_address);
2716 }
2717 };
2718
2719 // Launch a kernel that reads every element in the updates tensor. We could
2720 // also do one kernel per window instead if bounds checks turn out to be a
2721 // bottleneck.
2722 LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
2723 desc.updates_shape, ir_emitter_context_->gpu_device_info());
2724 UpdateLaunchDimensions(launch_dimensions, thunk,
2725 ir_emitter_context_->llvm_module());
2726
2727 return ParallelLoopEmitter(loop_body_emitter, desc.updates_shape,
2728 launch_dimensions, &b_)
2729 .EmitLoop(desc.name,
2730 desc.get_index_type(launch_dimensions.launch_bound()));
2731 }
2732
2733 // This transformation should be migrated off. See b/171334474.
2734 StatusOr<HloComputation*>
2735 IrEmitterUnnested::GetOrCreateSubComputationFromRegion(mlir::Region* region,
2736 bool is_fusion) {
2737 std::unique_ptr<HloModule>& module = scratch_nested_computations_[region];
2738 if (module == nullptr) {
2739 std::vector<Shape> operand_shapes, output_shapes;
2740 if (is_fusion) {
2741 mlir::Operation* clone = region->getParentOp()->clone();
2742 region = &mlir::cast<mlir::lmhlo::FusionOp>(clone).region();
2743 TF_RETURN_IF_ERROR(
2744 ProcessFusionForConversion(region, &operand_shapes, &output_shapes));
2745 }
2746
2747 xla::XlaComputation xla_computation;
2748 mlir::MlirToHloConversionOptions options;
2749 options.propagate_layouts = true;
2750 TF_RETURN_IF_ERROR(
2751 ConvertRegionToComputation(region, &xla_computation, options));
2752
2753 if (is_fusion) {
2754 region->getParentOp()->erase();
2755 }
2756
2757 TF_ASSIGN_OR_RETURN(auto program_shape, xla_computation.GetProgramShape());
2758 TF_ASSIGN_OR_RETURN(
2759 module, HloModule::CreateFromProto(xla_computation.proto(),
2760 HloModuleConfig(program_shape)));
2761
2762 if (is_fusion) {
2763 HloComputation* fused_computation = module->entry_computation();
2764
2765 CHECK_EQ(operand_shapes.size(), fused_computation->num_parameters());
2766 for (int i = 0; i < fused_computation->num_parameters(); i++) {
2767 *fused_computation->parameter_instruction(i)
2768 ->mutable_shape()
2769 ->mutable_layout() = operand_shapes[i].layout();
2770 }
2771 HloInstruction* root = fused_computation->root_instruction();
2772 // Manually fold Tuple(GTE(a, 0), GTE(a, 1), GTE(a, 2), ...) to a.
2773       // FusedIrEmitter doesn't take GTE ops because we aim to eliminate tuples
2774       // as much as possible.
2775 if (root->opcode() == HloOpcode::kTuple) {
2776 [&] {
2777 HloInstruction* real_root = nullptr;
2778 int expected_tuple_index = 0;
2779 for (HloInstruction* operand : root->operands()) {
2780 if (operand->opcode() != HloOpcode::kGetTupleElement) {
2781 return;
2782 }
2783 if (real_root == nullptr) {
2784 real_root = operand->mutable_operand(0);
2785 } else if (real_root != operand->operand(0)) {
2786 return;
2787 }
2788 if (expected_tuple_index != operand->tuple_index()) {
2789 return;
2790 }
2791 expected_tuple_index++;
2792 }
2793 fused_computation->set_root_instruction(real_root);
2794 std::vector<HloInstruction*> to_be_removed;
2795 to_be_removed.push_back(root);
2796 for (HloInstruction* operand : root->operands()) {
2797 to_be_removed.push_back(operand);
2798 }
2799 for (auto instr : to_be_removed) {
2800 TF_CHECK_OK(fused_computation->RemoveInstruction(instr));
2801 }
2802
2803 root = real_root;
2804 }();
2805 }
2806
2807 if (output_shapes.size() > 1) {
2808 CHECK(root->shape().IsTuple());
2809 CHECK_EQ(root->shape().tuple_shapes_size(), output_shapes.size());
2810
2811 for (int i = 0; i < output_shapes.size(); i++) {
2812 *root->mutable_shape()->mutable_tuple_shapes(i) = output_shapes.at(i);
2813 }
2814 } else {
2815 CHECK_EQ(1, output_shapes.size());
2816 *root->mutable_shape() = output_shapes[0];
2817 }
2818 }
2819 // Post-process the generated computation:
2820 // * Sanitize constant names, so that they can be used as LLVM global
2821 // symbols.
2822 // * Propagate layouts for tuple types.
2823 for (HloComputation* computation : module->computations()) {
2824 for (HloInstruction* instr : computation->MakeInstructionPostOrder()) {
2825 if (instr->opcode() == HloOpcode::kConstant) {
2826         // Note that IR emitters use the names of constants as LLVM symbol
2827         // names, so it is important that constants in the new module do not
2828         // collide by name with constants in the original module. Make them
2829         // unique by prepending the module name.
2830 //
2831 // TODO(timshen): A better solution would be to plumb the exact
2832 // constant names through original HLO -> LHLO -> MHLO -> HLO. This is
2833 // hard because XLA builder doesn't support setting names. Revisit
2834 // this once we get rid of this function, or don't rely on the op name
2835 // (which shouldn't be the identity) to generate LLVM symbols.
2836 instr->SetAndSanitizeName(llvm_ir::SanitizeConstantName(
2837 module->name() + "_" + instr->name()));
2838 }
2839 if (instr->shape().IsTuple() &&
2840 computation == module->entry_computation() &&
2841 instr != computation->root_instruction()) {
2842 return InternalError("Non-root tuple types are not handled.");
2843 }
2844 }
2845 }
2846 }
2847 return module->entry_computation();
2848 }
2849
2850 Status IrEmitterUnnested::HandleSort(HloInstruction* sort) {
2851 TF_ASSIGN_OR_RETURN(auto mlir_input, GetMlirEmitterInput(sort));
2852 return EmitSortFromMlir(mlir_input);
2853 }
2854
2855 Status IrEmitterUnnested::EmitSortFromMlir(MlirEmitterInput mlir_input) {
2856 auto sort_op = mlir::cast<mlir::lmhlo::SortOp>(mlir_input.op);
2857 MlirEmitterContext context;
2858 context.SetOperation(sort_op);
2859
2860 std::vector<std::unique_ptr<Thunk>> thunks;
2861
2862 const Shape& keys_shape = context.operand_shapes[0];
2863 int64 dimension_to_sort = sort_op.dimension();
2864 for (int64 i = 0; i < context.operand_shapes.size(); ++i) {
2865 // We assume that the layout of all involved operands and outputs is the
2866 // same.
2867 TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(keys_shape,
2868 context.operand_shapes[i]));
2869 TF_RET_CHECK(
2870 LayoutUtil::LayoutsInShapesEqual(keys_shape, context.output_shapes[i]));
2871
2872 // If possible, we share buffers. If that is not possible, we need to copy
2873 // the values, because the emitter does the sorting in-place.
2874 TF_ASSIGN_OR_RETURN(auto destination_buffer,
2875 GetAllocationSliceForMlir(sort_op.output()[i]));
2876 TF_ASSIGN_OR_RETURN(auto source_address,
2877 GetAllocationSliceForMlir(sort_op.operands()[i]));
2878 if (destination_buffer != source_address) {
2879 // TODO(b/26783907): Figure out why we never seem to share buffers for
2880 // key/value sort.
2881 VLOG(2) << context.name << " requires initial D2D copy for operand " << i;
2882 thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
2883 Thunk::ThunkInfo(),
2884 /*source_address=*/source_address,
2885 /*destination_buffer=*/destination_buffer,
2886 /*mem_size=*/ShapeUtil::ByteSizeOf(context.operand_shapes[i])));
2887 }
2888 }
2889
2890 uint64 dimension_to_sort_bound = keys_shape.dimensions(dimension_to_sort);
2891 int64 num_stages = tensorflow::Log2Ceiling(dimension_to_sort_bound);
2892 VLOG(2) << context.name << " requires " << num_stages << " stages.";
2893 CHECK_GE(1ULL << num_stages, dimension_to_sort_bound);
2894 CHECK_LT(1ULL << (num_stages - 1), dimension_to_sort_bound);
2895
2896 // Naive C++ code for the outer loops:
2897 //
2898 // for (int64 stage = 0; stage < Log2Ceiling(dimension_to_sort_bound);
2899 // ++stage) {
2900 // int64 first_xor_mask = (1LL << (stage + 1)) - 1;
2901 // SortInPlace(first_xor_mask);
2902 // for (int64 mask = stage - 1; mask >= 0; --mask) {
2903 // int64 later_xor_mask = 1LL << mask;
2904 // SortInPlace(later_xor_mask);
2905 // }
2906 // }
2907 //
2908 // This follows the alternative representation of the algorithm described on
2909 // Wikipedia: https://en.wikipedia.org/wiki/Bitonic_sorter
2910 //
2911 // Each mask specifies how to derive from one position in the array the
2912 // position with which it should be compared (we calculate the xor of the
2913 // position with the mask).
2914 // As an optimization, we can move the 'mask' loop to inside the
2915 // sorting/comparison loop if the comparisons happen within a small block of
2916 // the array. To make this work, we collect all consecutive masks that are
2917 // smaller than our chosen power of 2 tile size, and pass them to SortInPlace.
2918 // Each thread then processes one tile of data.
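  //
  // Illustrative example (assumed values, not from the original comment): with
  // kTileSize = 2048, a stage whose first xor mask is 4095 gets its own kernel
  // (4095 >= kTileSize), while that stage's remaining masks 1024, 512, ..., 1
  // are all < kTileSize and are accumulated into a single tiled kernel, which
  // is emitted once the next mask >= kTileSize appears (or at the very end).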
2919
2920 const uint64 kTileSize = std::min(2048ULL, 1ULL << num_stages);
2921
2922 // If we cannot combine several xor masks together, we don't use tiling, so we
2923 // calculate the standard launch dimensions for the shape. However we only
2924 // need to iterate through ~half of the dimension to sort (rounded up to the
2925 // next highest power of 2), because each iteration compares one pair of
2926 // elements.
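  // Illustrative example (assumed bound): with dimension_to_sort_bound = 100
  // we get num_stages = 7, so the standard iteration count along the sort
  // dimension is 1 << (7 - 1) = 64.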
2927 Shape standard_iteration_shape = keys_shape;
2928 uint64 standard_num_iterations_in_sort_dim = 1ULL << (num_stages - 1);
2929 standard_iteration_shape.set_dimensions(dimension_to_sort,
2930 standard_num_iterations_in_sort_dim);
2931 LaunchDimensions standard_launch_dimensions = CalculateLaunchDimensions(
2932 standard_iteration_shape, ir_emitter_context_->gpu_device_info());
2933
2934 // Calculate the launch dimensions for the case where we use tiling. We split
2935 // the dimension that should be sorted into tiles of size 'kTileSize'. This
2936 // means we first need to round 'dimension_to_sort_bound' up to be a multiple
2937 // of the tile size.
2938 int64 rounded_bound = RoundUpToNearest(dimension_to_sort_bound, kTileSize);
2939 Shape iteration_shape = keys_shape;
2940
2941 // We iterate through the element pairs that should be compared.
2942 uint64 num_iterations_in_sort_dim = rounded_bound / 2;
2943 iteration_shape.set_dimensions(dimension_to_sort, num_iterations_in_sort_dim);
2944 uint64 num_iterations = ShapeUtil::ElementsIn(iteration_shape);
2945
2946   // For correctness we need exactly 'kTileSize' / 2 threads per block. Each
2947   // thread is responsible for copying exactly two adjacent elements
2948 // into shared memory, and then does a comparison of two possibly different
2949 // elements taken from shared memory.
2950 const uint64 kThreadsPerBlock = kTileSize / 2;
2951
2952   // Check whether we should use any tiling. We might not be able to use it if
2953   // we don't have enough threads or enough shared memory. Tiling also does not
2954   // give a speedup if the tile size is < 128.
2955 int64 total_shared_memory_needed = 0;
2956 for (int64 i = 0; i < context.operand_shapes.size(); ++i) {
2957 total_shared_memory_needed +=
2958 kTileSize * ShapeUtil::ByteSizeOfPrimitiveType(
2959 context.operand_shapes[i].element_type());
2960 }
2961 bool no_tiling =
2962 kTileSize < 128 ||
2963 kThreadsPerBlock >
2964 ir_emitter_context_->gpu_device_info().threads_per_block_limit ||
2965 total_shared_memory_needed >
2966 ir_emitter_context_->gpu_device_info().shared_memory_per_block;
2967 VLOG(2) << absl::StreamFormat(
2968 "%s %s use tiling. No tiling if any of the following is true: "
2969 "kTileSize=%d < 128, "
2970 "kThreadsPerBlock=%d > threads_per_block_limit=%d, "
2971 "total_shared_memory_needed=%d > shared_memory_per_block=%d",
2972 context.name, (no_tiling ? "won't" : "will"), kTileSize, kThreadsPerBlock,
2973 ir_emitter_context_->gpu_device_info().threads_per_block_limit,
2974 total_shared_memory_needed,
2975 ir_emitter_context_->gpu_device_info().shared_memory_per_block);
2976
2977 uint64 num_blocks = CeilOfRatio(num_iterations, kThreadsPerBlock);
2978 LaunchDimensions tiled_launch_dimensions(num_blocks, kThreadsPerBlock);
2979 VLOG(2) << absl::StreamFormat("%s launch dims: %d blocks, %d threads/block",
2980 context.name, num_blocks, kThreadsPerBlock);
2981
2982 std::vector<llvm_ir::IrArray> ir_arrays;
2983 auto emit_kernel = [&](absl::Span<const int64> xor_masks) {
2984 VLOG(2) << absl::StreamFormat(
2985 "%s uses kernel for xor masks [%s]", context.name,
2986 absl::StrJoin(xor_masks, ", ", [](std::string* out, int64 xor_mask) {
2987 absl::StrAppendFormat(out, "0x%x", xor_mask);
2988 }));
2989 thunks.emplace_back();
2990 TF_ASSIGN_OR_RETURN(
2991 thunks.back(),
2992 BuildKernelThunkForMlir(sort_op, sort_op.output(), Thunk::ThunkInfo(),
2993 mlir_input.extra_slice, &ir_arrays));
2994 LaunchDimensions launch_dimensions = xor_masks.size() > 1
2995 ? tiled_launch_dimensions
2996 : standard_launch_dimensions;
2997 UpdateLaunchDimensions(launch_dimensions, thunks.back().get(),
2998 ir_emitter_context_->llvm_module());
2999 std::vector<IrArray> values_arrays;
3000 values_arrays.reserve(context.operand_shapes.size());
3001 for (int64 i = 0; i < context.operand_shapes.size(); ++i) {
3002 values_arrays.push_back(ir_arrays[i]);
3003 }
3004 TF_ASSIGN_OR_RETURN(const HloComputation* comparator,
3005 GetOrCreateSubComputationFromRegion(
3006 &sort_op.comparator(), /*is_fusion=*/false));
3007 return llvm_ir::EmitSortInPlace(
3008 dimension_to_sort, values_arrays, IrName(context.name), xor_masks, &b_,
3009 launch_dimensions,
3010 xor_masks.size() > 1 ? num_iterations_in_sort_dim
3011 : standard_num_iterations_in_sort_dim,
3012 kTileSize,
3013 [&](absl::Span<llvm::Value* const> operands, llvm::Value* output) {
3014 return EmitCallToNestedComputation(*comparator, operands, output);
3015 });
3016 };
3017 std::vector<int64> xor_masks;
3018 for (int64 stage = 0; stage < num_stages; ++stage) {
3019 for (int64 mask = stage; mask >= 0; --mask) {
3020 int64 xor_mask;
3021 if (mask == stage) {
3022 xor_mask = (1LL << (stage + 1)) - 1;
3023 } else {
3024 xor_mask = 1LL << mask;
3025 }
3026 if (xor_mask >= kTileSize || no_tiling) {
3027 if (!xor_masks.empty()) {
3028 TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
3029 xor_masks.clear();
3030 }
3031 TF_RETURN_IF_ERROR(emit_kernel({xor_mask}));
3032 } else {
3033 xor_masks.push_back(xor_mask);
3034 }
3035 }
3036 }
3037 if (!xor_masks.empty()) {
3038 TF_RETURN_IF_ERROR(emit_kernel(xor_masks));
3039 }
3040 VLOG(2) << absl::StreamFormat(
3041 "%s requires %d thunks (including any D2D copies)", context.name,
3042 thunks.size());
3043
3044 AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
3045 mlir_input.thunk_info, std::move(thunks)));
3046 return Status::OK();
3047 }
3048
3049 template <typename ThunkType, typename OpT>
3050 Status IrEmitterUnnested::EmitReplicaOrPartitionIdFromMlir(
3051 MlirEmitterInput input) {
3052 auto op = mlir::cast<OpT>(input.op);
3053 TF_ASSIGN_OR_RETURN(BufferAllocation::Slice result_slice,
3054 GetAllocationSliceForMlir(op.getOperand()));
3055 AddThunkToThunkSequence(
3056 absl::make_unique<ThunkType>(input.thunk_info, result_slice));
3057 return Status::OK();
3058 }
3059
3060 Status IrEmitterUnnested::HandleReplicaId(HloInstruction* hlo) {
3061 TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
3062 return EmitReplicaOrPartitionIdFromMlir<ReplicaIdThunk,
3063 mlir::lmhlo::ReplicaIdOp>(input);
3064 }
3065
3066 Status IrEmitterUnnested::HandlePartitionId(HloInstruction* hlo) {
3067 TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
3068 return EmitReplicaOrPartitionIdFromMlir<PartitionIdThunk,
3069 mlir::lmhlo::PartitionIdOp>(input);
3070 }
3071
3072 Status IrEmitterUnnested::HandleCollectivePermute(HloInstruction* hlo) {
3073 TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
3074 return EmitCollectivePermuteFromMlir(input);
3075 }
3076
3077 Status IrEmitterUnnested::EmitCollectivePermuteFromMlir(
3078 MlirEmitterInput input) {
3079 auto collective_permute_op =
3080 mlir::cast<mlir::lmhlo::CollectivePermuteOp>(input.op);
3081 if (collective_permute_op.channel_id())
3082 return Unimplemented("collective permute with channel_id not implemented");
3083 using source_dest_pairs_t = std::vector<std::pair<int64, int64>>;
3084 TF_ASSIGN_OR_RETURN(
3085 source_dest_pairs_t source_dest_pairs,
3086 ConvertNx2Attribute(collective_permute_op.source_target_pairs()));
3087
3088 TF_ASSIGN_OR_RETURN(
3089 BufferAllocation::Slice source_slice,
3090 GetAllocationSliceForMlir(collective_permute_op.operand()));
3091 TF_ASSIGN_OR_RETURN(
3092 BufferAllocation::Slice result_slice,
3093 GetAllocationSliceForMlir(collective_permute_op.output()));
3094
3095 AddThunkToThunkSequence(absl::make_unique<CollectivePermuteThunk>(
3096 input.thunk_info, std::move(source_dest_pairs), source_slice,
3097 result_slice));
3098 return Status::OK();
3099 }
3100
3101 template <typename NcclThunkType, typename OpTy>
3102 Status IrEmitterUnnested::EmitNcclThunkFromMlir(MlirEmitterInput input) {
3103 OpTy op = mlir::cast<OpTy>(input.op);
3104 int64 replica_count = hlo_module_config_.replica_count();
3105 VLOG(2) << NcclThunkType::GetName() << "; replica count: " << replica_count
3106 << "; operand count: " << op.operands().size()
3107 << "; NCCL is enabled: " << NcclThunkType::NcclIsEnabled();
3108
3109 // Note the replica_count == 1 case is handled via device-to-device copy
3110 // below.
3111 bool should_use_nccl_thunk =
3112 replica_count > 1 && NcclThunkType::CanImplement(op);
3113
3114   // Stash relevant information in NcclCollectiveThunk::Buffer even if we may
3115   // not generate an NCCL collective thunk.
3116 std::vector<NcclCollectiveThunk::Buffer> buffers;
3117 buffers.reserve(op.operands().size());
3118 for (auto it : llvm::zip(op.operands(), op.results())) {
3119 mlir::Value operand = std::get<0>(it);
3120 mlir::Value result = std::get<1>(it);
3121 const Shape shape = TypeToShape(operand.getType());
3122 TF_ASSIGN_OR_RETURN(auto source_slice, GetAllocationSliceForMlir(operand));
3123 TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSliceForMlir(result));
3124 buffers.push_back(NcclCollectiveThunk::Buffer{
3125 /*element_count=*/ShapeUtil::ElementsIn(shape),
3126 /*source_buffer=*/source_slice,
3127 /*destination_buffer=*/dest_slice});
3128 }
3129
3130 if (should_use_nccl_thunk) {
3131 auto nccl_thunk =
3132 absl::make_unique<NcclThunkType>(input.thunk_info, op, replica_count,
3133 /*buffers=*/std::move(buffers));
3134 AddThunkToThunkSequence(std::move(nccl_thunk));
3135 return Status::OK();
3136 }
3137
3138 if (replica_count != 1) {
3139 string message = absl::StrFormat(
3140 "Requested %s not implemented on GPU; replica_count: %d; "
3141 "operand_count: %d; NCCL support: %d",
3142 NcclThunkType::GetName(), replica_count, op.operands().size(),
3143 NcclThunkType::NcclIsEnabled());
3144 if (!op.operands().empty()) {
3145 const Shape shape = TypeToShape(op.operands().front().getType());
3146 absl::StrAppendFormat(&message, "; first operand array element-type: %s",
3147 PrimitiveType_Name(shape.element_type()));
3148 }
3149 return Unimplemented("%s", message);
3150 }
3151
3152   // A collective op with one replica is simply the identity function. Buffer
3153   // assignment expects a copy, so that's what we do.
3154 std::vector<std::unique_ptr<Thunk>> thunks;
3155 for (int64 i = 0; i < buffers.size(); i++) {
3156 const Shape shape = TypeToShape(op.operands()[i].getType());
3157 thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
3158 buffers.size() == 1 ? input.thunk_info : Thunk::ThunkInfo(),
3159 /*source_address=*/buffers[i].source_buffer,
3160 /*destination_buffer=*/buffers[i].destination_buffer,
3161 /*mem_size=*/ShapeUtil::ByteSizeOf(shape)));
3162 }
3163 if (thunks.size() == 1) {
3164 AddThunkToThunkSequence(std::move(thunks[0]));
3165 } else {
3166 AddThunkToThunkSequence(absl::make_unique<SequentialThunk>(
3167 input.thunk_info, std::move(thunks)));
3168 }
3169 return Status::OK();
3170 }
3171
3172 Status IrEmitterUnnested::HandleAllGather(HloInstruction* hlo) {
3173 TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
3174 return EmitNcclThunkFromMlir<NcclAllGatherThunk, mlir::lmhlo::AllGatherOp>(
3175 input);
3176 }
3177
3178 Status IrEmitterUnnested::HandleAllReduce(HloInstruction* hlo) {
3179 TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
3180 return EmitNcclThunkFromMlir<NcclAllReduceThunk, mlir::lmhlo::AllReduceOp>(
3181 input);
3182 }
3183
3184 Status IrEmitterUnnested::HandleAllToAll(HloInstruction* hlo) {
3185 TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(hlo));
3186 return EmitNcclThunkFromMlir<NcclAllToAllThunk, mlir::lmhlo::AllToAllOp>(
3187 input);
3188 }
3189
3190 Status IrEmitterUnnested::HandleInfeed(HloInstruction* xla_infeed) {
3191 TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(xla_infeed));
3192
3193 auto infeed_op = mlir::cast<mlir::lmhlo::InfeedOp>(input.op);
3194
3195 std::vector<ShapedSlice> dest_slices;
3196 dest_slices.reserve(infeed_op.outputs().size());
3197
3198 for (mlir::Value output : infeed_op.outputs()) {
3199 TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(output));
3200 const Shape& shape = TypeToShape(output.getType());
3201 dest_slices.push_back(ShapedSlice{slice, shape});
3202 }
3203
3204 AddThunkToThunkSequence(
3205 absl::make_unique<InfeedThunk>(input.thunk_info, std::move(dest_slices)));
3206 return Status::OK();
3207 }
3208
3209 Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) {
3210 TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(outfeed));
3211
3212 auto outfeed_op = mlir::cast<mlir::lmhlo::OutfeedOp>(input.op);
3213
3214 std::vector<ShapedSlice> source_slices;
3215 source_slices.reserve(outfeed_op.operands().size());
3216
3217 for (mlir::Value operand : outfeed_op.operands()) {
3218 TF_ASSIGN_OR_RETURN(auto slice, GetAllocationSliceForMlir(operand));
3219 const Shape& shape = TypeToShape(operand.getType());
3220 source_slices.push_back(ShapedSlice{slice, shape});
3221 }
3222
3223 AddThunkToThunkSequence(absl::make_unique<OutfeedThunk>(
3224 input.thunk_info, std::move(source_slices)));
3225 return Status::OK();
3226 }
3227
3228 Status IrEmitterUnnested::HandleAfterAll(HloInstruction* after_all) {
3229 return Status::OK();
3230 }
3231
3232 // Figures out how to access the buffers for all subshapes of hlo's operands and
3233 // for hlo itself (i.e. all the buffers produced by HLO).
3234 //
3235 // Returns a vector of `HloBufferSlice`s, one for each HLO subshape `hlo` needs
3236 // to access (including one or more for itself).
3237 //
3238 // This function conservatively assumes that we'll touch all sub-buffers of
3239 // every operand and of the output.
3240 static std::vector<HloBufferSlice> GetHloBufferSlices(
3241 const HloInstruction* hlo, const BufferAssignment& buffer_assn) {
3242 std::vector<HloBufferSlice> result;
3243 absl::flat_hash_set<std::pair<const HloInstruction*, ShapeIndex>>
3244 inserted_buffer_slices;
3245
3246 // Tries to find a slice plus an array of indices i1, ..., iN such that the
3247 // sub-buffer for instr at index can be found at slice[i1]...[iN].
3248 auto find_slice_for = [&](const HloInstruction* instr,
3249 const ShapeIndex& index)
3250 -> optional<std::pair<BufferAllocation::Slice, ShapeIndex>> {
3251 // Simple, common case: Is the buffer for instr known at runtime? If so,
3252 // we're done.
3253 auto slice = buffer_assn.GetUniqueSlice(instr, index);
3254 if (slice.ok()) {
3255 return {{slice.ValueOrDie(), ShapeIndex()}};
3256 }
3257
3258 // If that didn't work, walk up any bitcasts that we might see. These must
3259 // appear before any GTE instructions, because it's illegal to bitcast to a
3260 // tuple type.
3261 const HloInstruction* parent = instr;
3262 while (parent->IsEffectiveBitcast()) {
3263 parent = parent->operand(0);
3264
3265 auto slice = buffer_assn.GetUniqueSlice(parent, {});
3266 if (slice.ok()) {
3267 return {{slice.ValueOrDie(), ShapeIndex()}};
3268 }
3269 }
3270
3271 // Check whether instr is a GTE instruction. If it is, see if we can get a
3272 // buffer for its parent, and continue walking up parents until we find a
3273 // defined buffer or we hit something that's not a GTE.
3274 ShapeIndex gte_indices;
3275 while (parent->opcode() == HloOpcode::kGetTupleElement) {
3276 gte_indices.push_front(parent->tuple_index());
3277 parent = parent->operand(0);
3278
3279 auto slice = buffer_assn.GetUniqueSlice(parent, {});
3280 if (slice.ok()) {
3281 return {{slice.ValueOrDie(), gte_indices}};
3282 }
3283 }
3284
3285 // Finally, if we don't know the buffer for instr at index, see if we know
3286 // the buffer for instr at index without its last element. If so, we can
3287 // dynamically find the buffer for instr by dereferencing a pointer in that
3288 // buffer. Continue looking this way until we run out of elements in
3289 // 'index'.
3290 //
3291 // We can almost always get a buffer without resorting to this. The only
3292 // exception is for cases where the relevant sub-buffer is truly unknowable,
3293 // for example the sub-buffer of a tuple-shaped select.
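    //
    // Illustrative example (hypothetical shapes): if the unique slice for
    // instr is known at index {1} but not at {1, 0}, the loop below returns
    // that slice together with gte_indices = {0}; the kernel then loads the
    // pointer stored at element 0 of that buffer to reach the sub-buffer.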
3294 ShapeIndex new_index = index;
3295 while (!new_index.empty()) {
3296 gte_indices.push_front(new_index.back());
3297 new_index.pop_back();
3298 auto slice = buffer_assn.GetUniqueSlice(instr, new_index);
3299 if (slice.ok()) {
3300 return {{slice.ValueOrDie(), gte_indices}};
3301 }
3302 }
3303
3304 return nullopt;
3305 };
3306
3307 // Adds entries for all subshapes of instr to `slices`.
3308 auto add_slices_for = [&](const HloInstruction* instr) {
3309 ShapeUtil::ForEachSubshape(
3310 instr->shape(), [&](const Shape& /*shape*/, const ShapeIndex& index) {
3311 if (!inserted_buffer_slices.insert({instr, index}).second) {
3312 // HLOs can have duplicate operands; don't bother redoing work.
3313 return;
3314 }
3315 auto maybe_slice = find_slice_for(instr, index);
3316 if (maybe_slice.has_value()) {
3317 HloBufferSlice hlo_buffer_slice;
3318 hlo_buffer_slice.instr = instr;
3319 hlo_buffer_slice.hlo_index = index;
3320 hlo_buffer_slice.buffer_slice = maybe_slice->first;
3321 hlo_buffer_slice.gte_index = maybe_slice->second;
3322 result.push_back(hlo_buffer_slice);
3323 } else {
3324 VLOG(1) << "Couldn't find buffer for " << instr->ToString()
3325 << " at index " << index.ToString();
3326 }
3327 });
3328 };
3329
3330 add_slices_for(hlo);
3331 for (const HloInstruction* operand : hlo->operands()) {
3332 // Conservatively assume we'll need the buffers for all subshapes of the
3333 // operand.
3334 add_slices_for(operand);
3335 }
3336
3337 return result;
3338 }
3339
3340 std::unique_ptr<KernelThunk>
3341 IrEmitterUnnested::BuildKernelThunkFromBufferSlices(
3342 absl::string_view name, Thunk::ThunkInfo thunk_info,
3343 absl::Span<const BufferSlice* const> slices,
3344 std::function<void(const BufferSlice*, llvm::Value*)>
3345 bind_slice_to_ir_value) {
3346 // Figure out which buffer allocations need to be passed as arguments to our
3347 // kernel. This is simply all of the allocations referenced in slices,
3348 // plus the XLA temp buffer (if we have it). We always include the temp
3349 // buffer because even if the kernel itself doesn't use it, a nested
3350 // subcomputation within the kernel (e.g. a kMap's computation) might.
3351 std::unordered_set<const BufferAllocation*> buffers_needed;
3352 for (auto* slice : slices) {
3353 buffers_needed.insert(slice->buffer_slice.allocation());
3354 }
3355 absl::optional<const BufferAllocation*> temp_buffer;
3356 for (const BufferAllocation& alloc : ir_emitter_context_->allocations()) {
3357 if (alloc.IsPreallocatedTempBuffer()) {
3358 if (!temp_buffer.has_value()) {
3359 // Retrieve the first seen temp buffer.
3360 temp_buffer = &alloc;
3361 }
3362 }
3363 }
3364 if (temp_buffer.has_value()) {
3365 buffers_needed.insert(*temp_buffer);
3366 }
3367
3368 // We'll pass a pointer to each of the elements of `buffers` to our kernel, in
3369 // this order.
3370 std::vector<const BufferAllocation*> non_constant_buffers;
3371 absl::c_copy_if(buffers_needed, std::back_inserter(non_constant_buffers),
3372 [](const BufferAllocation* allocation) {
3373 return !allocation->is_constant();
3374 });
3375
3376 absl::c_sort(non_constant_buffers,
3377 [](const BufferAllocation* a, const BufferAllocation* b) {
3378 return a->index() < b->index();
3379 });
3380
3381 llvm::Function* kernel = BuildKernelPrototype(name, non_constant_buffers);
3382
3383 // Build a map from a BufferAllocation to the corresponding argument in our
3384 // kernel.
3385 std::unordered_map<const BufferAllocation*, llvm::Value*> kernel_args;
3386 {
3387 auto arg_it = kernel->arg_begin();
3388 auto buffers_it = non_constant_buffers.begin();
3389 for (; arg_it != kernel->arg_end(); ++arg_it, ++buffers_it) {
3390 kernel_args[*buffers_it] = arg_it;
3391
3392 // Annotate all allocations with LLVM's `noalias`.
3393 // There are three kinds of allocations:
3394 // * Read-only allocations, aka input parameters that are not aliased with
3395 // outputs.
3396 // * Read-write allocations, including all output buffers, some of which
3397 // may alias with input HLO parameters, but aliased HLO buffers are always
3398 // assigned with the same allocation.
3399 // * The temp buffer.
3400 //
3401 // Read-only allocations may overlap with each other, but since they are
3402 // not mutated, they can always be annotated with `noalias` per LLVM
3403 // semantics.
3404 //
3405 // Read-write allocations and the temp buffer don't overlap with any
3406 // allocations, therefore they can also be annotated with `noalias`.
3407 kernel->addParamAttr(
3408 arg_it->getArgNo(),
3409 llvm::Attribute::get(arg_it->getContext(), llvm::Attribute::NoAlias));
3410 }
3411 }
3412
3413 // For each buffer our kernel might want to touch, bind it to a value derived
3414 // from our kernel args.
3415 for (auto* slice : slices) {
3416 const BufferAllocation::Slice& buffer_slice = slice->buffer_slice;
3417 const ShapeIndex& gte_index = slice->gte_index;
3418
3419 llvm::Value* loc;
3420 if (buffer_slice.allocation()->is_constant()) {
3421 loc = ir_emitter_context_->llvm_module()->getGlobalVariable(
3422 llvm_ir::ConstantBufferAllocationToGlobalName(
3423 *buffer_slice.allocation()));
3424 CHECK_NE(loc, nullptr);
3425 } else {
3426 loc = InBoundsGEP(kernel_args.at(buffer_slice.allocation()),
3427 {b_.getInt64(buffer_slice.offset())});
3428 }
3429
3430 // If gte_index is nonempty, we have to dereference `loc` to get to the
3431 // value we're ultimately interested in.
3432 llvm::Type* int8_double_pointer =
3433 llvm::PointerType::get(b_.getInt8PtrTy(), /*AddressSpace=*/0);
3434 for (int64 idx : gte_index) {
3435 loc = b_.CreatePointerBitCastOrAddrSpaceCast(loc, int8_double_pointer);
3436 loc = Load(InBoundsGEP(loc, {b_.getInt64(idx)}));
3437 }
3438
3439 bind_slice_to_ir_value(slice, loc);
3440 }
3441
3442   // Bind the temp buffer so that nested subcomputations can find it if they
3443   // need it.
3444 if (temp_buffer.has_value()) {
3445 bindings_.SetTempBufferBase(kernel_args.at(*temp_buffer));
3446 } else {
3447 bindings_.SetTempBufferBase(
3448 llvm::ConstantPointerNull::get(b_.getInt8PtrTy()));
3449 }
3450
3451 return absl::make_unique<KernelThunk>(thunk_info, non_constant_buffers,
3452 std::string(kernel->getName()));
3453 }
3454
3455 std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunk(
3456 const HloInstruction* inst, bool implements_whole_instruction) {
3457 std::vector<HloBufferSlice> hlo_slices =
3458 GetHloBufferSlices(inst, ir_emitter_context_->buffer_assignment());
3459
3460 std::vector<BufferSlice*> slice_ptrs;
3461 slice_ptrs.reserve(hlo_slices.size());
3462 for (auto& slice : hlo_slices) {
3463 slice_ptrs.push_back(&slice);
3464 }
3465
3466 return BuildKernelThunkFromBufferSlices(
3467 inst->name(),
3468 implements_whole_instruction ? GetThunkInfo(inst) : Thunk::ThunkInfo(),
3469 slice_ptrs, [this](const BufferSlice* slice, llvm::Value* value) {
3470 const HloBufferSlice* hlo_buffer_slice =
3471 static_cast<const HloBufferSlice*>(slice);
3472 const HloInstruction* instr = hlo_buffer_slice->instr;
3473 const ShapeIndex& index = hlo_buffer_slice->hlo_index;
3474 VLOG(3) << "Buffer for " << instr->ToString() << " at "
3475 << index.ToString() << " is found in slice "
3476 << hlo_buffer_slice->buffer_slice.ToString() << " at GTE index "
3477 << hlo_buffer_slice->gte_index.ToString();
3478
3479 bindings_.BindHloToIrValue(*instr, value, index);
3480 });
3481 }
3482
3483 std::unique_ptr<KernelThunk> IrEmitterUnnested::BuildKernelThunkForMlirImpl(
3484 absl::string_view name, Thunk::ThunkInfo thunk_info,
3485 absl::Span<const MlirBufferSlice> slices,
3486 std::vector<llvm_ir::IrArray>* ir_arrays) {
3487 absl::flat_hash_set<BufferAllocation::Slice> buffers_written;
3488 std::vector<const BufferSlice*> slice_ptrs;
3489 slice_ptrs.reserve(slices.size());
3490 for (auto& slice : slices) {
3491 slice_ptrs.push_back(&slice);
3492 if (slice.written) {
3493 buffers_written.insert(slice.buffer_slice);
3494 }
3495 }
3496
3497 ir_arrays->clear();
3498 return BuildKernelThunkFromBufferSlices(
3499 name, thunk_info, slice_ptrs,
3500 [&](const BufferSlice* slice, llvm::Value* value) {
3501 const auto& mlir_slice = static_cast<const MlirBufferSlice&>(*slice);
3502
3503 llvm_ir::IrArray ir_array(
3504 CastToTypedValue(mlir_slice.shape, value, &b_), mlir_slice.shape);
3505 if (!buffers_written.contains(slice->buffer_slice)) {
3506 ir_array.MarkInvariantOverWholeProgram(&value->getContext());
3507 }
3508
3509 ir_arrays->push_back(ir_array);
3510 });
3511 }
3512
3513 StatusOr<std::unique_ptr<KernelThunk>>
3514 IrEmitterUnnested::BuildKernelThunkForMlir(
3515 mlir::Operation* op, mlir::ValueRange operands, Thunk::ThunkInfo thunk_info,
3516 absl::optional<MlirBufferSlice> extra_slice,
3517 std::vector<llvm_ir::IrArray>* ir_arrays) {
3518 TF_RET_CHECK(!mlir::isa<mlir::lmhlo::FusionOp>(op));
3519
3520 std::vector<MlirBufferSlice> slices;
3521 for (mlir::Value operand : operands) {
3522 slices.emplace_back();
3523 auto& slice = slices.back();
3524 TF_ASSIGN_OR_RETURN(slice.buffer_slice, GetAllocationSliceForMlir(operand));
3525 slice.written = WritesMlirBuffer(op, operand);
3526 slice.shape = TypeToShape(operand.getType());
3527 }
3528 if (extra_slice) {
3529 slices.push_back(*extra_slice);
3530 }
3531 std::string name = mlir::GetNameFromLoc(op->getLoc());
3532 return BuildKernelThunkForMlirImpl(name, thunk_info, slices, ir_arrays);
3533 }
3534
3535 StatusOr<std::unique_ptr<KernelThunk>>
3536 IrEmitterUnnested::BuildKernelThunkForMlir(
3537 mlir::Operation* op, Thunk::ThunkInfo thunk_info,
3538 absl::optional<MlirBufferSlice> extra_slice,
3539 std::vector<llvm_ir::IrArray>* ir_arrays) {
3540 if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
3541 auto operands = GetHloOperands(op);
3542 auto outputs = GetHloOutputs(op);
3543
3544 std::vector<MlirBufferSlice> slices;
3545 for (auto operand : operands) {
3546 slices.emplace_back();
3547 auto& slice = slices.back();
3548 TF_ASSIGN_OR_RETURN(slice.buffer_slice,
3549 GetAllocationSliceForMlir(operand));
3550 slice.written = false;
3551 slice.shape = TypeToShape(operand.getType());
3552 }
3553 for (auto output : outputs) {
3554 slices.emplace_back();
3555 auto& slice = slices.back();
3556 TF_ASSIGN_OR_RETURN(slice.buffer_slice,
3557 GetAllocationSliceForMlir(output));
3558 slice.written = true;
3559 slice.shape = TypeToShape(output.getType());
3560 }
3561 std::string name = mlir::GetNameFromLoc(op->getLoc());
3562 if (extra_slice) {
3563 slices.push_back(*extra_slice);
3564 }
3565 return BuildKernelThunkForMlirImpl(name, thunk_info, slices, ir_arrays);
3566 }
3567 return BuildKernelThunkForMlir(op, op->getOperands(), thunk_info, extra_slice,
3568 ir_arrays);
3569 }
3570
3571 std::unique_ptr<Thunk> IrEmitterUnnested::BuildConstantInitializerThunk(
3572 absl::Span<const uint8> init_value, const BufferAllocation::Slice& dest,
3573 const Shape& output_shape) {
3574 int64 num_bytes = init_value.size();
3575 if (absl::c_all_of(init_value, [](uint8 byte) { return byte == 0; })) {
3576 return absl::make_unique<MemzeroThunk>(Thunk::ThunkInfo(), dest);
3577 }
3578
3579 // If the literal is 8 or 16 bits wide, we can emit a 32-bit memset by
3580 // repeating the literal 4 or 2 times, so long as the destination buffer is
3581 // an even multiple of 32 bits long.
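  // Illustrative example (assumed byte value): a one-byte literal 0xAB becomes
  // pattern16 = 0xABAB and pattern32 = 0xABABABAB below.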
3582 if ((num_bytes == 1 || num_bytes == 2) &&
3583 ShapeUtil::ByteSizeOf(output_shape) % 4 == 0) {
3584 uint16 pattern16;
3585 if (num_bytes == 1) {
3586 uint8 b = init_value.front();
3587 pattern16 = uint16{b} | (uint16{b} << 8);
3588 } else {
3589 memcpy(&pattern16, init_value.data(), sizeof(pattern16));
3590 }
3591 uint32 pattern32 = uint32{pattern16} | (uint32{pattern16} << 16);
3592 return absl::make_unique<Memset32BitValueThunk>(Thunk::ThunkInfo(),
3593 pattern32, dest);
3594 }
3595
3596 // If the literal is an even multiple of 32 bits wide, we can emit a 32-bit
3597 // memset so long as all 32-bit words of the scalar are equal to each other.
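  // Illustrative example (assumed literal): an 8-byte literal whose two 32-bit
  // halves are bit-identical passes the memcmp check below, and the repeated
  // 32-bit word is used as the memset value.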
3598 if (num_bytes >= 4 && num_bytes % 4 == 0 &&
3599 memcmp(init_value.data(), init_value.data() + 4, init_value.size() - 4) ==
3600 0) {
3601 uint32 word;
3602 memcpy(&word, init_value.data(), sizeof(word));
3603 return absl::make_unique<Memset32BitValueThunk>(Thunk::ThunkInfo(), word,
3604 dest);
3605 }
3606
3607 return nullptr;
3608 }
3609
3610 StatusOr<std::unique_ptr<Thunk>>
3611 IrEmitterUnnested::TryBuildConstantInitializerThunk(mlir::Value init_value,
3612 mlir::Value dest) {
3613 mlir::DenseElementsAttr const_init;
3614 if (auto get_global_memref = mlir::dyn_cast_or_null<mlir::GetGlobalMemrefOp>(
3615 init_value.getDefiningOp())) {
3616 auto global_memref =
3617 mlir::SymbolTable::lookupNearestSymbolFrom<mlir::GlobalMemrefOp>(
3618 get_global_memref, get_global_memref.name());
3619 if (global_memref.constant() && global_memref.initial_value()) {
3620 // If the initial value happens to be a constant, generate a specialized
3621 // thunk.
3622 const_init = global_memref.initial_value()
3623 .getValue()
3624 .cast<mlir::DenseElementsAttr>();
3625 }
3626 } else if (auto constant = mlir::dyn_cast_or_null<mlir::mhlo::ConstOp>(
3627 init_value.getDefiningOp())) {
3628 const_init = constant.value().dyn_cast<mlir::DenseElementsAttr>();
3629 }
3630
3631 if (const_init) {
3632 std::vector<uint8> literal_bytes;
3633 TF_RETURN_IF_ERROR(
3634 CopyDenseElementsDataToXlaFormat(const_init, &literal_bytes));
3635
3636 TF_ASSIGN_OR_RETURN(auto dest_slice, GetAllocationSliceForMlir(dest));
3637
3638 const Shape dest_shape = TypeToShape(dest.getType());
3639 auto thunk =
3640 BuildConstantInitializerThunk(literal_bytes, dest_slice, dest_shape);
3641 if (thunk) {
3642 return {std::move(thunk)};
3643 }
3644 }
3645 return std::unique_ptr<Thunk>();
3646 }
3647
3648 StatusOr<std::unique_ptr<Thunk>>
3649 IrEmitterUnnested::BuildInitializerThunkForMlir(mlir::Operation* op,
3650 mlir::Value init_value,
3651 mlir::Value dest) {
3652   // The initial value must be a scalar memref.
3653 auto init_type = init_value.getType().dyn_cast<mlir::MemRefType>();
3654 TF_RET_CHECK(init_type.getRank() == 0);
3655
3656 TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> constant_init_thunk,
3657 TryBuildConstantInitializerThunk(init_value, dest));
3658 if (constant_init_thunk) {
3659 return {std::move(constant_init_thunk)};
3660 }
3661
3662 // Otherwise fall back to our slow initializer code. The thunk in this case
3663 // will just need the IR arrays for the initial value and the destination.
3664 std::vector<llvm_ir::IrArray> ir_arrays;
3665 TF_ASSIGN_OR_RETURN(
3666 std::unique_ptr<KernelThunk> kernel_thunk,
3667 BuildKernelThunkForMlir(op, {init_value, dest}, Thunk::ThunkInfo(), {},
3668 &ir_arrays));
3669 const llvm_ir::IrArray init_array = ir_arrays[0];
3670 const llvm_ir::IrArray dest_array = ir_arrays[1];
3671
3672 const Shape dest_shape = TypeToShape(dest.getType());
3673 LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
3674 dest_shape, ir_emitter_context_->gpu_device_info());
3675 UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
3676 ir_emitter_context_->llvm_module());
3677
3678 std::string name = mlir::GetNameFromLoc(op->getLoc());
3679 TF_RETURN_IF_ERROR(ParallelLoopEmitter(
3680 [=](const IrArray::Index& index) {
3681 return init_array.EmitReadArrayElement(index, &b_);
3682 },
3683 dest_array, launch_dimensions, &b_)
3684 .EmitLoop(mlir::GetNameFromLoc(op->getLoc())));
3685
3686 // Convert unique_ptr<KernelThunk> to StatusOr<unique_ptr<Thunk>>.
3687 return {std::move(kernel_thunk)};
3688 }
3689
3690 StatusOr<std::unique_ptr<Thunk>>
3691 IrEmitterUnnested::BuildFusedInitializerThunkForMlir(
3692 mlir::lmhlo::FusionOp fusion, int output_index) {
3693 auto reduce = mlir::dyn_cast_or_null<mlir::mhlo::ReduceOp>(
3694 fusion.getFusionResults()[output_index].getDefiningOp());
3695
3696 TF_RET_CHECK(reduce);
3697 TF_RET_CHECK(reduce.getNumResults() == 1);
3698
3699 mlir::Value init_value = reduce.init_values()[0];
3700 mlir::Value dest = fusion.getOutputBuffers()[output_index];
3701 TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> constant_init_thunk,
3702 TryBuildConstantInitializerThunk(init_value, dest));
3703 if (constant_init_thunk) {
3704 return {std::move(constant_init_thunk)};
3705 }
3706
3707 auto input_buffers = fusion.getInputBuffers();
3708
3709 std::vector<llvm_ir::IrArray> ir_arrays;
3710 TF_ASSIGN_OR_RETURN(
3711 std::unique_ptr<KernelThunk> kernel_thunk,
3712 BuildKernelThunkForMlir(fusion, Thunk::ThunkInfo(), {}, &ir_arrays));
3713 const llvm_ir::IrArray dest_array =
3714 ir_arrays[input_buffers.size() + output_index];
3715
3716 const Shape dest_shape = TypeToShape(dest.getType());
3717 LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
3718 dest_shape, ir_emitter_context_->gpu_device_info());
3719 UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
3720 ir_emitter_context_->llvm_module());
3721
3722 const HloComputation* fused_computation =
3723 *GetOrCreateSubComputationFromRegion(&fusion.region(),
3724 /*is_fusion=*/true);
3725
3726 // If init_value was fused into this reduce we have to generate it first.
3727 GpuElementalIrEmitter elemental_emitter(hlo_module_config_,
3728 ir_emitter_context_->llvm_module(),
3729 &b_, GetNestedComputer());
3730
3731 FusedIrEmitter fused_emitter(&elemental_emitter);
3732 for (int i = 0; i < fused_computation->num_parameters(); i++) {
3733 fused_emitter.BindGenerator(
3734 fused_computation->parameter_instruction(i),
3735 [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
3736 return ir_arrays[i].EmitReadArrayElement(index, &b_);
3737 });
3738 }
3739 HloInstruction* instr = fused_computation->root_instruction();
3740 if (instr->opcode() != HloOpcode::kTuple) {
3741 CHECK_EQ(0, output_index);
3742 } else {
3743 instr = instr->mutable_operand(output_index);
3744 }
3745 TF_RET_CHECK(instr->shape().IsArray());
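  // For a reduce, operand 1 is its init value, so the generator below emits
  // the (possibly fused) init value for every element of the destination.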
3746 TF_ASSIGN_OR_RETURN(auto generator,
3747 fused_emitter.GetGenerator(instr->operand(1)));
3748 TF_RETURN_IF_ERROR(
3749 ParallelLoopEmitter(generator, dest_array, launch_dimensions, &b_)
3750 .EmitLoop(mlir::GetNameFromLoc(fusion.getLoc())));
3751 return {std::move(kernel_thunk)};
3752 }
3753
3754 namespace {
3755
3756 // Checks that the buffers corresponding to the given two HLOs share the same
3757 // allocation.
3758 Status CheckHloBuffersShareAllocation(
3759 const HloInstruction* a, const HloInstruction* b, const ShapeIndex& index,
3760 const BufferAssignment& buffer_assignment) {
3761 const BufferAllocation::Slice slice_a =
3762 buffer_assignment.GetUniqueSlice(a, index).ConsumeValueOrDie();
3763 const BufferAllocation::Slice slice_b =
3764 buffer_assignment.GetUniqueSlice(b, index).ConsumeValueOrDie();
3765 if (slice_a != slice_b) {
3766 return InternalError(
3767 "instruction %s %s does not share allocation with instruction %s %s",
3768 a->ToString(), slice_a.ToString(), b->ToString(), slice_b.ToString());
3769 }
3770 return Status::OK();
3771 }
3772
3773 // Checks that all buffers used during while loop iteration share the same
3774 // buffer allocation. This includes buffers for while result, while init
3775 // operand, condition parameter, body parameter and body result.
3776 // Returns OK on success, error status otherwise.
3777 Status CheckWhileBuffersShareAllocation(
3778 const HloInstruction* xla_while,
3779 const BufferAssignment& buffer_assignment) {
3780 return ShapeUtil::ForEachSubshapeWithStatus(
3781 xla_while->shape(),
3782 [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
3783 const HloInstruction* condition_parameter =
3784 xla_while->while_condition()->parameter_instruction(0);
3785 const HloComputation* body = xla_while->while_body();
3786 const HloInstruction* body_parameter = body->parameter_instruction(0);
3787 const HloInstruction* body_result = body->root_instruction();
3788 TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
3789 xla_while, xla_while->operand(0), index, buffer_assignment));
3790 TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
3791 xla_while, condition_parameter, index, buffer_assignment));
3792 TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
3793 xla_while, body_parameter, index, buffer_assignment));
3794 TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
3795 xla_while, body_result, index, buffer_assignment));
3796 return Status::OK();
3797 });
3798 }
3799
3800 // Checks that the buffers used in a conditional instruction are shared with the
3801 // operands and result as follows:
3802 // * The result buffer of the conditional should share the allocation with the
3803 // result buffers of each branch computation.
3804 // * The buffer of operand b+1 should share the allocation with the buffer of
3805 // the parameter 0 instruction of the b'th computation.
3806 Status CheckConditionalBuffersShareAllocation(
3807 const HloInstruction* conditional,
3808 const BufferAssignment& buffer_assignment) {
3809 TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
3810 conditional->shape(),
3811 [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
3812 for (auto branch_computation : conditional->branch_computations()) {
3813 TF_RETURN_IF_ERROR(CheckHloBuffersShareAllocation(
3814 conditional, branch_computation->root_instruction(), index,
3815 buffer_assignment));
3816 }
3817 return Status::OK();
3818 }));
3819 for (int j = 0; j < conditional->branch_count(); ++j) {
3820 TF_RETURN_IF_ERROR(ShapeUtil::ForEachSubshapeWithStatus(
3821 conditional->operand(j + 1)->shape(),
3822 [&](const Shape& /*subshape*/, const ShapeIndex& index) -> Status {
3823 return CheckHloBuffersShareAllocation(
3824 conditional->operand(j + 1),
3825 conditional->branch_computation(j)->parameter_instruction(0),
3826 index, buffer_assignment);
3827 }));
3828 }
3829 return Status::OK();
3830 }
3831
3832 } // namespace
3833
3834 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildWhileThunk(
3835 const HloInstruction* hlo) {
3836 // Check that all while-related buffers share an allocation.
3837 TF_CHECK_OK(CheckWhileBuffersShareAllocation(
3838 hlo, ir_emitter_context_->buffer_assignment()));
3839
3840 // Generate thunk sequence for while 'condition'.
3841 HloComputation* condition = hlo->while_condition();
3842 TF_ASSIGN_OR_RETURN(auto ir_emitter_condition,
3843 IrEmitterUnnested::Create(hlo_module_config_, condition,
3844 ir_emitter_context_));
3845 TF_RETURN_IF_ERROR(condition->Accept(ir_emitter_condition.get()));
3846
3847 // Generate thunk sequence for while 'body'.
3848 HloComputation* body = hlo->while_body();
3849 TF_ASSIGN_OR_RETURN(
3850 auto ir_emitter_body,
3851 IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_));
3852 TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get()));
3853
3854 const auto* index_map = ir_emitter_context_->profile_index_map();
3855 absl::optional<size_t> condition_profile_index, body_profile_index;
3856 if (index_map) {
3857 condition_profile_index = index_map->GetProfileIndexFor(*condition);
3858 body_profile_index = index_map->GetProfileIndexFor(*body);
3859 }
3860
3861 return std::unique_ptr<Thunk>(new WhileThunk(
3862 GetThunkInfo(hlo),
3863 GetAllocationSlice(*condition->root_instruction()), // cond result
3864 ir_emitter_condition->ConsumeThunkSequence(),
3865 ir_emitter_body->ConsumeThunkSequence(), condition_profile_index,
3866 body_profile_index));
3867 }
3868
3869 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildForThunk(
3870 const HloInstruction* hlo, const int64 loop_limit) {
3871 // Check that all while-related buffers share an allocation.
3872 TF_CHECK_OK(CheckWhileBuffersShareAllocation(
3873 hlo, ir_emitter_context_->buffer_assignment()));
3874
3875   // Generate thunk sequence for while 'body' (will be used as the For loop body).
3876 HloComputation* body = hlo->while_body();
3877 TF_ASSIGN_OR_RETURN(
3878 auto ir_emitter_body,
3879 IrEmitterUnnested::Create(hlo_module_config_, body, ir_emitter_context_));
3880 TF_RETURN_IF_ERROR(body->Accept(ir_emitter_body.get()));
3881
3882 const auto* index_map = ir_emitter_context_->profile_index_map();
3883 absl::optional<size_t> body_profile_index;
3884 if (index_map) {
3885 body_profile_index = index_map->GetProfileIndexFor(*body);
3886 }
3887
3888 return std::unique_ptr<Thunk>(new ForThunk(
3889 GetThunkInfo(hlo), loop_limit, ir_emitter_body->ConsumeThunkSequence(),
3890 body_profile_index));
3891 }
3892
3893 StatusOr<std::unique_ptr<Thunk>> IrEmitterUnnested::BuildConditionalThunk(
3894 const HloInstruction* hlo) {
3895 // Check that the buffers used in conditional are shared with the operands and
3896 // result appropriately.
3897 TF_CHECK_OK(CheckConditionalBuffersShareAllocation(
3898 hlo, ir_emitter_context_->buffer_assignment()));
3899
3900 std::vector<BufferAllocation::Slice> branch_operands;
3901 std::vector<ThunkSequence> branch_thunks;
3902 std::vector<absl::optional<size_t>> branch_profile_indices;
3903
3904 int branch_count = hlo->branch_count();
3905 branch_thunks.reserve(branch_count);
3906 branch_profile_indices.reserve(branch_count);
3907
3908 const auto* index_map = ir_emitter_context_->profile_index_map();
3909
3910 for (int j = 0; j < branch_count; ++j) {
3911 branch_operands.emplace_back(GetAllocationSlice(*hlo->operand(j + 1)));
3912 HloComputation* branch_computation = hlo->branch_computation(j);
3913 TF_ASSIGN_OR_RETURN(
3914 auto ir_emitter,
3915 IrEmitterUnnested::Create(hlo_module_config_, branch_computation,
3916 ir_emitter_context_));
3917 TF_CHECK_OK(branch_computation->Accept(ir_emitter.get()));
3918 branch_thunks.push_back(std::move(*ir_emitter->ConsumeThunkSequence()));
3919
3920 absl::optional<size_t> profile_index;
3921 if (index_map) {
3922 profile_index = index_map->GetProfileIndexFor(*branch_computation);
3923 }
3924 branch_profile_indices.push_back(profile_index);
3925 }
3926
3927 ConditionalThunkConfig config = GetConditionalThunkConfig(
3928 hlo, std::move(branch_thunks), std::move(branch_profile_indices));
3929 return std::unique_ptr<Thunk>(new ConditionalThunk(
3930 GetThunkInfo(hlo), std::move(config),
3931 GetAllocationSlice(*hlo->operand(0)), branch_operands));
3932 }
3933
3934 Status IrEmitterUnnested::EmitTargetElementLoop(
3935 const HloInstruction& hlo, const llvm_ir::ElementGenerator& body_emitter) {
3936 return InternalError("This should be unreachable");
3937 }
3938
3939 // Gets the output offset as calculated from thread_id.x (to be applied to the
3940 // offset calculated from block_id and thread_id.y).
3941 static llvm::Value* GetStartOffsetX(const KernelMappingScheme& mapping_scheme,
3942 llvm::Value* thread_id_x,
3943 llvm::Type* index_ty,
3944 llvm::IRBuilder<>* b) {
3945 auto constant = [&](int64 val) {
3946 return llvm::ConstantInt::get(index_ty, val);
3947 };
3948 if (mapping_scheme.GetIndexingOrder() == kStridedIndexingX) {
3949 return thread_id_x;
3950 } else if (mapping_scheme.GetIndexingOrder() == kStridedLinearIndexingX) {
3951 return b->CreateMul(thread_id_x, constant(mapping_scheme.GetVectorSize()));
3952 }
3953 CHECK_EQ(mapping_scheme.GetIndexingOrder(), kLinearIndexingX);
3954 int64 x_num_steps =
3955 mapping_scheme.GetTileSizeX() / mapping_scheme.GetNumThreadsX();
3956 return b->CreateMul(thread_id_x, constant(x_num_steps));
3957 }
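// Illustration of the function above (hypothetical values, not from the
// original source): with GetTileSizeX() = 128 and GetNumThreadsX() = 32,
// thread t starts at
//   t                   for kStridedIndexingX (it then strides by num_threads_x),
//   t * GetVectorSize() for kStridedLinearIndexingX,
//   t * 4               for kLinearIndexingX (x_num_steps = 128 / 32 = 4, so
//                       each thread owns a contiguous chunk of 4 elements).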
3958
3959 // Calls `emit_elem_function()` `x_num_steps` times. If
3960 // `vector_size`==1, then each element index passed to
3961 // `emit_elem_function()` will be separated by `step_x`. If `vector_size`>1,
3962 // then `x_num_steps` must be a multiple of `vector_size`. In that case, it
3963 // triggers a different indexing order that is vectorizable by
3964 // LLVM. It generates many groups of calls to `emit_elem_function`. Each
3965 // group is separated by `step_x` elements. Inside a group, elements
3966 // are consecutive. If `check_x_tile_bounds` is true, then it will check
3967 // if the element index is in bound compared to `tile_width` before
3968 // calling `emit_elem_function`.
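//
// Illustration (hypothetical values): with x_num_steps = 8, vector_size = 4 and
// step_x = 32, the loop emits two groups of four offsets relative to
// `start_offset_x`: {0, 1, 2, 3} for j = 0 and {128, 129, 130, 131} for j = 1
// (computed as j * step_x * vector_size + i), i.e. consecutive elements inside
// each group, which is the pattern LLVM can vectorize.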
3969 static void UnrollInnerTileLoop(
3970 bool check_x_tile_bounds, int64 x_num_steps, int64 step_x,
3971 int64 vector_size, const string& loop_name, KernelSupportLibrary* ksl,
3972 llvm::Value* start_offset_x, llvm::Value* y_loc, llvm::Value* tile_width,
3973 const IrArray::Index& source_idx, llvm::IRBuilder<>* b,
3974 const IrEmitterUnnested::EmitElementFunction* emit_elem_function) {
3975 llvm::Type* index_ty = tile_width->getType();
3976 auto constant = [&](int64 val) {
3977 return llvm::ConstantInt::get(index_ty, val);
3978 };
3979 IrArray::Index source_idx_x_base = source_idx.AddOffsetToDim(y_loc, kDimY, b);
3980 for (int64 j = 0; j < x_num_steps / vector_size; j++) {
3981 for (int64 i = 0; i < vector_size; i++) {
3982 int64 linear_index = j * vector_size + i;
3983 llvm::Value* x_loc = b->CreateAdd(constant(j * step_x * vector_size + i),
3984 start_offset_x, "x_loc");
3985 IrArray::Index source_idx_x = source_idx_x_base.AddOffsetToDim(
3986 constant(j * step_x * vector_size + i), kDimX, b);
3987 auto emit_element = [&] {
3988 return (*emit_elem_function)(source_idx_x, y_loc, x_loc, linear_index);
3989 };
3990 if (check_x_tile_bounds) {
3991 ksl->If(loop_name + "_x_in_tile", b->CreateICmpULT(x_loc, tile_width),
3992 emit_element);
3993 } else {
3994 emit_element();
3995 }
3996 }
3997 }
3998 }
3999
4000 void IrEmitterUnnested::EmitTile(
4001 const KernelMappingScheme& mapping_scheme,
4002 const IrArray::Index& tile_origin_index, const string& loop_name,
4003 KernelSupportLibrary* ksl, const ThreadIdInfo& thread_id_info,
4004 llvm::Value* tile_height, llvm::Value* tile_width,
4005 const IrEmitterUnnested::EmitElementFunction& emit_elem_function) {
4006 llvm::Type* index_ty = tile_width->getType();
4007 auto constant = [&](int64 val) {
4008 return llvm::ConstantInt::get(index_ty, val);
4009 };
4010 int64 num_threads_x = mapping_scheme.GetNumThreadsX();
4011 llvm::Value* num_threads_y = constant(mapping_scheme.GetNumThreadsY());
4012 int64 tile_size_x = mapping_scheme.GetTileSizeX();
4013
4014 int64 x_num_steps = tile_size_x / num_threads_x;
4015 llvm::Value* start_offset_x = GetStartOffsetX(
4016 mapping_scheme, thread_id_info.thread_id_x, index_ty, &b_);
4017
4018   // With the dilated (strided) mapping scheme, each thread steps with a stride
4019   // equal to the number of threads.
4020   // Otherwise, the stride is one, but each thread's start offset is multiplied
4021   // by the number of steps it can take.
4022 int64 step_x =
4023 mapping_scheme.GetIndexingOrder() == kLinearIndexingX ? 1 : num_threads_x;
4024 int64 vector_size = mapping_scheme.GetVectorSize();
4025
4026 IrArray::Index source_idx =
4027 tile_origin_index.AddOffsetToDim(start_offset_x, kDimX, &b_);
4028
4029 auto ceil_of_ratio = [&](llvm::Value* a, llvm::Value* b) {
4030 return b_.CreateUDiv(b_.CreateAdd(b_.CreateAdd(a, b), constant(-1)), b);
4031 };
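  // The lambda above computes ceil(a / b) as (a + b - 1) / b in unsigned
  // arithmetic; e.g. a = 10, b = 4 gives (10 + 4 - 1) / 4 = 3.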
4032
4033 // True iff all threads always execute all instructions in the tiling
4034 // dimension X.
4035 bool x_tile_fits =
4036 mapping_scheme.GetDimsInElems()[kDimX] % tile_size_x == 0 &&
4037 mapping_scheme.GetRowContiguous();
4038
4039 // The outer loop below is simply doing:
4040 //
4041 // for (int y_loc=thread_id_y; y_loc<tile_height; y_loc+=num_threads_y)
4042 //
4043 //
4044 // However, in order to avoid an LLVM optimization triggering the ptxas bug,
4045 // we write this loop in a convoluted way:
4046 //
4047 // y_bound = ceil_of_ratio(tile_height - thread_id_y, num_threads_y)
4048 // for (int y_indvar=0; y_indvar<y_bound; y_indvar+=1)
4049 // y_loc = thread_id_y + y_indvar * num_threads_y
4050 //
4051 // TODO(cheshire): Once ptxas is fixed and TF switches to it, remove the
4052 // workaround.
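  // Illustration (hypothetical values): with tile_height = 100, thread_id_y = 3
  // and num_threads_y = 4, y_bound = ceil((100 - 3) / 4) = 25, so y_loc takes
  // the values 3, 7, ..., 99, i.e. the same rows the straightforward loop would
  // visit.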
4053 ksl->For(
4054 loop_name + "_y_in_tile",
4055 /*start=*/constant(0),
4056 /*end=*/
4057 ceil_of_ratio(b_.CreateSub(tile_height, thread_id_info.thread_id_y),
4058 num_threads_y),
4059 /*step=*/constant(1), [&](llvm::Value* y_indvar) {
4060 llvm::Value* y_loc = b_.CreateAdd(
4061 thread_id_info.thread_id_y, b_.CreateMul(y_indvar, num_threads_y));
4062 auto unroll_inner_tile_loop = [&](bool check_x_tile_bounds) {
4063 return UnrollInnerTileLoop(check_x_tile_bounds, x_num_steps, step_x,
4064 vector_size, loop_name, ksl,
4065 start_offset_x, y_loc, tile_width,
4066 source_idx, &b_, &emit_elem_function);
4067 };
4068
4069         // Only take this path when we unroll in a way vectorizable by
4070         // LLVM. The partially-filled tile is special-cased only for even
4071         // row sizes: for an odd row size, every other row isn't aligned to
4072         // the vector size, so LLVM can't vectorize it.
4073 if (!x_tile_fits &&
4074 mapping_scheme.GetIndexingOrder() == kStridedLinearIndexingX) {
4075 ksl->If(
4076 loop_name + "_is_full_tile",
4077 // For the last block, tile_width will be the number of
4078 // elements left.
4079 b_.CreateICmpEQ(constant(mapping_scheme.GetTileSizeX()),
4080 tile_width),
4081 [&] { unroll_inner_tile_loop(/*check_x_tile_bounds=*/false); },
4082 [&] { unroll_inner_tile_loop(/*check_x_tile_bounds=*/true); });
4083 } else {
4084 unroll_inner_tile_loop(/*check_x_tile_bounds=*/!x_tile_fits);
4085 }
4086 });
4087 }
4088
4089 // Emits code to process a tensor element in a tile for the given kCopy HLO that
4090 // performs a 0-2-1 transpose.
4091 //
4092 // index: The index for the first output element in the normalized tensor. The
4093 // normalized tensor is the resulting tensor after collapsing contiguous
4094 // dimensions that play the same role in the transpose.
4095 // mapping_scheme: Kernel mapping scheme specifying the tiling
4096 void IrEmitterUnnested::EmitTileElementForCopy(
4097 const Shape& output_shape, const llvm_ir::IrArray& output_array,
4098 const llvm_ir::IrArray::Index& index,
4099 const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
4100 llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) {
4101 // TODO(jlebar): Add AA metadata to this load.
4102 llvm::Instruction* load_from_shmem_buffer =
4103 Load(GEP(param_shmem_buffers[0], {b_.getInt64(0), x_loc, y_loc}),
4104 "output_element");
4105 Shape output_reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
4106 output_shape.element_type(), mapping_scheme.GetDimsInElems());
4107 // When the output_reduced_shape is a 0-2-1 transpose of the input shape,
4108 // the 0-2-1 transpose is achieved through EmitWriteArrayElement.
4109 output_array.CastToShape(output_reduced_shape, &b_)
4110 .EmitWriteArrayElement(index, load_from_shmem_buffer, &b_);
4111 }
4112
4113 static IrArray::Index GetUnnormalizedIndex(
4114 const IrArray::Index& normalized_shape_index,
4115 const Shape& unnormalized_shape, llvm::IRBuilder<>* b_,
4116 const KernelMappingScheme& kernel_mapping_scheme) {
4117 DCHECK_EQ(normalized_shape_index.size(), 3);
4118   // If the normalization only adds a new dimension of size 1,
4119   // generate simpler indexing. LLVM doesn't always simplify the more
4120   // complicated indexing, and this prevents it from vectorizing some
4121   // cases. We do this only for the major-to-minor memory layout.
4122 if (unnormalized_shape.rank() == 2 && unnormalized_shape.has_layout() &&
4123 unnormalized_shape.dimensions()[0] == normalized_shape_index.dims()[1] &&
4124 unnormalized_shape.dimensions()[1] == normalized_shape_index.dims()[2] &&
4125 unnormalized_shape.layout().minor_to_major(1) == 0) {
4126 CHECK_EQ(normalized_shape_index.dims()[0], 1);
4127 auto multidim = normalized_shape_index.multidim();
4128 return IrArray::Index({multidim[1], multidim[2]}, unnormalized_shape,
4129 normalized_shape_index.GetType());
4130 }
4131 llvm::Value* linear = normalized_shape_index.Linearize(
4132 kernel_mapping_scheme.GetDimsInElems(), b_);
4133 return IrArray::Index(linear, unnormalized_shape, b_);
4134 }
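// Illustration of GetUnnormalizedIndex above (hypothetical shapes): a
// normalized index (0, i, j) over dims {1, 10, 20} maps directly to the 2D
// index (i, j) of an unnormalized f32[10,20]{1,0} via the fast path; otherwise
// the index is linearized (here to i * 20 + j) and then delinearized against
// the unnormalized shape.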
4135
4136 // Emits code to process a tensor element in a tile for the given kLoop fusion
4137 // HLO containing parameters that are 0-2-1 transpose of its outputs.
4138 //
4139 // index: The index for the first output element in the normalized tensor, that
4140 // is the resulting tensor after collapsing contiguous dimensions that play
4141 // the same role in the transpose.
4142 // kernel_info: Other information to support the kernel code generation.
4143 void IrEmitterUnnested::EmitTileElementForFusion(
4144 mlir::lmhlo::FusionOp fusion,
4145 absl::Span<const llvm_ir::IrArray> operand_arrays,
4146 absl::Span<const llvm_ir::IrArray> output_arrays,
4147 const llvm_ir::IrArray::Index& index,
4148 const KernelMappingScheme& mapping_scheme, llvm::Value* y_loc,
4149 llvm::Value* x_loc, absl::Span<llvm::Value* const> param_shmem_buffers) {
4150 const HloComputation* fused_computation =
4151 *GetOrCreateSubComputationFromRegion(&fusion.region(),
4152 /*is_fusion=*/true);
4153 GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
4154 GetNestedComputer());
4155 FusedIrEmitter fused_emitter(&elem_emitter);
4156 for (int i = 0; i < operand_arrays.size(); i++) {
4157 llvm_ir::ElementGenerator gen;
4158 if (llvm::Value* param_tile_buffer = param_shmem_buffers[i]) {
4159 gen = [this, param_tile_buffer, x_loc,
4160 y_loc](llvm_ir::IrArray::Index index) {
4161 // TODO(jlebar): Add AA metadata to this load. Tile buffers are
4162 // global variables, so LLVM's points-to analysis doesn't help us
4163 // much. And we want the AA info to be present before address
4164 // spaces are inferred (which is pretty late in the pipeline), so
4165 // even if we had address-space-based AA in LLVM, it wouldn't help
4166 // us much here.
4167 return b_.CreateLoad(
4168 b_.CreateGEP(param_tile_buffer,
4169 {index.GetConstantWithIndexType(0), x_loc, y_loc}),
4170 "tiled_buffer");
4171 };
4172 } else {
4173 auto array = operand_arrays[i];
4174 auto name = fused_computation->parameter_instruction(i)->name();
4175 gen = [this, array, name](const llvm_ir::IrArray::Index& index) {
4176 return array.EmitReadArrayElement(index, &b_, name);
4177 };
4178 }
4179 fused_emitter.BindGenerator(fused_computation->parameter_instruction(i),
4180 std::move(gen));
4181 }
4182 IrArray::Index untiled_index = GetUnnormalizedIndex(
4183 index, output_arrays[0].GetShape(), &b_, mapping_scheme);
4184 llvm_ir::ElementGenerator output_generator =
4185 *fused_emitter.GetGenerator(fused_computation->root_instruction());
4186 llvm::Value* output_value = output_generator(untiled_index).ValueOrDie();
4187 if (output_arrays.size() > 1) {
4188 DCHECK(output_value->getType()->isStructTy());
4189 DCHECK_EQ(output_value->getType()->getStructNumElements(),
4190 output_arrays.size() - 1);
4191 for (int64 i = 0; i < output_arrays.size() - 1; ++i) {
4192 output_arrays[i].EmitWriteArrayElement(
4193 untiled_index, ExtractValue(output_value, i), &b_);
4194 }
4195 } else {
4196 output_arrays[0].EmitWriteArrayElement(untiled_index, output_value, &b_);
4197 }
4198 }
4199
4200 static mlir::Operation* GetReduceFromUnnestedMlir(mlir::Operation* unnested_hlo,
4201 int index) {
4202 if (mlir::isa<mlir::lmhlo::ReduceOp>(unnested_hlo)) {
4203 CHECK_EQ(0, index);
4204 return unnested_hlo;
4205 }
4206 if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo)) {
4207 auto results = fusion.getFusionResults();
4208 CHECK(index < results.size())
4209 << MlirToString(unnested_hlo) << " vs " << index;
4210 return results[index].getDefiningOp();
4211 }
4212 return nullptr;
4213 }
4214
4215 void IrEmitterUnnested::EmitPrologueForReduction(
4216 mlir::Operation* unnested_hlo, absl::Span<const int> instr_index_group,
4217 HloComputation* fused_computation, FusedIrEmitter* fused_emitter,
4218 absl::Span<const llvm_ir::IrArray> operand_ir_arrays,
4219 absl::Span<const llvm_ir::IrArray> result_ir_arrays,
4220 ReductionCodegenInfo* reduction_info) {
4221 VLOG(10) << "Emit prologue for reduction: " << MlirToString(unnested_hlo);
4222 mlir::Operation* first_reduce = nullptr;
4223 for (int index : instr_index_group) {
4224 mlir::Operation* reduce_inst =
4225 GetReduceFromUnnestedMlir(unnested_hlo, index);
4226
4227 if (!IsReductionFromOrToContiguousDimensions(reduce_inst)) {
4228 continue;
4229 }
4230
4231 auto results = GetHloOutputs(reduce_inst);
4232 CHECK_EQ(1, results.size());
4233 Shape reduce_inst_shape = TypeToShape(results[0].getType());
4234
4235 VLOG(10) << "Emit prologue for reduction: " << MlirToString(reduce_inst);
4236 if (first_reduce == nullptr) {
4237 first_reduce = reduce_inst;
4238 } else {
4239 CHECK(absl::c_equal(
4240 first_reduce->getAttrOfType<mlir::DenseIntElementsAttr>("dimensions"),
4241 reduce_inst->getAttrOfType<mlir::DenseIntElementsAttr>(
4242 "dimensions")));
4243 }
4244
4245 AddressVector* reduction_input_addresses =
4246 reduction_info->GetMutableReductionInputAddresses();
4247 llvm::Type* element_type = llvm_ir::PrimitiveTypeToIrType(
4248 reduce_inst_shape.element_type(), ir_emitter_context_->llvm_module());
4249 llvm::AllocaInst* reduction_input_address =
4250 llvm_ir::EmitAllocaAtFunctionEntry(element_type,
4251 "reduction_input_address", &b_);
4252 reduction_input_addresses->push_back(reduction_input_address);
4253
4254 int num_partial_results = reduction_info->GetNumPartialResults();
4255 AddressVector* partial_result_addresses =
4256 reduction_info->GetMutablePartialResultAddresses();
4257 llvm::AllocaInst* partial_result_address =
4258 llvm_ir::EmitAllocaAtFunctionEntryWithCount(
4259 element_type, /*ArraySize=*/b_.getInt32(num_partial_results),
4260 ("partial_reduction_result." + llvm::Twine(index)).str(), &b_);
4261 partial_result_addresses->push_back(partial_result_address);
4262
4263 // Initialize the partial result with the initial value of the reduction.
4264 llvm::Value* init_ir_value;
4265 if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo)) {
4266 const HloInstruction* reduce_hlo = fused_computation->root_instruction();
4267 if (reduce_hlo->opcode() == HloOpcode::kTuple) {
4268 reduce_hlo = reduce_hlo->operand(index);
4269 }
4270 const HloInstruction* init_value = reduce_hlo->operand(1);
4271
4272 init_ir_value = (*fused_emitter->GetGenerator(
4273 init_value))(IrArray::Index(b_.getInt32Ty()))
4274 .ValueOrDie();
4275 } else {
4276 init_ir_value = operand_ir_arrays[1].EmitReadArrayElement(
4277 IrArray::Index(b_.getInt32Ty()), &b_);
4278 }
4279
4280 for (int i = 0; i < num_partial_results; ++i) {
4281 Store(init_ir_value,
4282 InBoundsGEP(partial_result_address, {b_.getInt32(i)}));
4283 }
4284 reduction_info->GetMutableInitialValues()->push_back(init_ir_value);
4285
4286 auto& mapping_scheme = reduction_info->GetKernelMappingScheme();
4287 int64 num_threads_x = mapping_scheme.GetNumThreadsX();
4288 llvm::Type* primitive_type = llvm_ir::PrimitiveTypeToIrType(
4289 reduce_inst_shape.element_type(), module_);
4290 llvm::Type* buffer_type = [&] {
4291 if (reduction_info->IsRowReduction()) {
4292 // Allocate __shared__ cache[num_partial_results][kWarpSize].
4293 return llvm::ArrayType::get(
4294 llvm::ArrayType::get(primitive_type, kWarpSize),
4295 num_partial_results);
4296 } else {
4297 // Allocate __shared__
4298 // cache[num_partial_results][num_threads][num_threads + 1], where
4299 // num_threads == num_threads_x == num_threads_y. The "+1" is used to
4300 // avoid bank conflicts.
4301 CHECK_EQ(num_threads_x, mapping_scheme.GetNumThreadsY());
4302 return llvm::ArrayType::get(
4303 llvm::ArrayType::get(
4304 llvm::ArrayType::get(primitive_type, num_threads_x + 1),
4305 num_threads_x),
4306 num_partial_results);
4307 }
4308 }();
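    // Illustration (hypothetical values): for a row reduction of f32 with
    // num_partial_results = 2, buffer_type is [2 x [32 x float]] (one
    // warp-sized row per partial result); for a column reduction with
    // num_threads_x = num_threads_y = 32 it is [2 x [32 x [33 x float]]],
    // where the "+1" column keeps the threads of a warp on distinct banks.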
4309 llvm::GlobalVariable* shared_cache_per_reduce =
4310 llvm_ir::AllocateSharedMemoryTile(b_.GetInsertBlock()->getModule(),
4311 buffer_type,
4312 absl::StrCat("shared_cache_", index));
4313 reduction_info->GetMutableSharedCache()->push_back(shared_cache_per_reduce);
4314 }
4315 CHECK(first_reduce);
4316 }
4317
4318 void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForAllReduces(
4319 absl::Span<HloComputation* const> reducers,
4320 absl::Span<llvm::AllocaInst* const> partial_result_addresses,
4321 int threads_per_block) {
4322 CHECK_EQ(reducers.size(), partial_result_addresses.size());
4323 for (int i = 0; i != reducers.size(); i++) {
4324 EmitFullWarpShuffleDownLoopForReduce(
4325 reducers[i], partial_result_addresses[i]->getType()->getElementType(),
4326 partial_result_addresses[i], threads_per_block);
4327 }
4328 }
4329
4330 void IrEmitterUnnested::EmitFullWarpShuffleDownLoopForReduce(
4331 HloComputation* reducer, llvm::Type* element_type,
4332 llvm::Value* partial_result_address, int threads_per_block) {
4333 // This only works when the block size is a multiple of 32 threads.
4334 CHECK_EQ(threads_per_block % 32, 0);
4335 for (int distance = 16; distance >= 1; distance /= 2) {
4336 int bit_width = llvm_ir::GetSizeInBits(element_type);
4337 llvm::Value* result_from_other_lane = llvm_ir::EmitAllocaAtFunctionEntry(
4338 element_type, "result_from_other_lane", &b_);
4339 // Bitcast cannot be applied to aggregate types (even packed ones), so
4340 // we bitcast addresses of load/store to intN* of the same bit-width.
4341 llvm::Type* shuffled_value_type =
4342 element_type->isStructTy() ? b_.getIntNTy(bit_width) : element_type;
4343 auto convert_pointer_for_shuffle = [&](llvm::Value* ptr) {
4344 return b_.CreatePointerBitCastOrAddrSpaceCast(
4345 ptr, shuffled_value_type->getPointerTo());
4346 };
4347 llvm::Value* partial_result =
4348 Load(convert_pointer_for_shuffle(partial_result_address),
4349 "partial_reduction_result");
4350 Store(EmitFullWarpShuffleDown(partial_result, b_.getInt32(distance), &b_),
4351 convert_pointer_for_shuffle(result_from_other_lane));
4352 TF_CHECK_OK(EmitCallToNestedComputation(
4353 *reducer, {partial_result_address, result_from_other_lane},
4354 partial_result_address));
4355 }
4356 }
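// Illustration of the loop above: for a 32-lane warp the reducer is applied at
// shuffle distances 16, 8, 4, 2 and 1, so after five steps lane 0 holds the
// reduction of all 32 partial results (other lanes hold partial combinations).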
4357
4358 // Given the IrArray index of a reduction input, returns the linear address of
4359 // the reduction output as if the reduction were going to keep the input shape
4360 // with the dimensions being reduced moved.
4361 static llvm::Value* GetUntransposedOutputLinearAddress(
4362 llvm::IRBuilder<>* b, const llvm_ir::IrArray::Index& index,
4363 const ReductionCodegenInfo& reduction_info) {
4364 const KernelMappingScheme& kernel_mapping_scheme =
4365 reduction_info.GetKernelMappingScheme();
4366 if (reduction_info.IsRowReduction()) {
4367 // For row-reduction, y-coordinate determines which row we write into.
4368 return index[kDimY];
4369 }
4370 // For column reduction, we get the transposed address.
4371 absl::Span<const int64> dims_in_elem = kernel_mapping_scheme.GetDimsInElems();
4372 llvm::Value* x_dim_size = index.GetConstantWithIndexType(dims_in_elem[kDimX]);
4373 llvm::Value* x_block_offset = b->CreateMul(index[kDimZ], x_dim_size);
4374 return b->CreateAdd(x_block_offset, index[kDimX]);
4375 }
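// Illustration (hypothetical shape): for a column reduction over f32[Z,Y,X],
// the function above returns z * X + x, the linear address of element (z, x)
// in the output with the reduced Y dimension removed; for a row reduction it
// simply returns the y coordinate.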
4376
4377 void IrEmitterUnnested::EmitEpilogueForReduction(
4378 llvm::Type* index_ty, mlir::Operation* unnested_hlo,
4379 absl::Span<const int> instr_index_group,
4380 absl::Span<const llvm_ir::IrArray> result_ir_arrays,
4381 absl::Span<HloComputation* const> reducers,
4382 const ReductionCodegenInfo& reduction_info,
4383 const TilingKernelInfo& tiling_kernel_info) {
4384 const KernelMappingScheme& mapping_scheme =
4385 reduction_info.GetKernelMappingScheme();
4386 auto constant = [&](uint64 c) -> llvm::Constant* {
4387 return llvm::ConstantInt::get(index_ty, c);
4388 };
4389
4390 IrEmitterUnnested::ThreadIdInfo thread_id_info =
4391 EmitThreadIdInfo(mapping_scheme.GetThreadsPerBlock(), index_ty,
4392 mapping_scheme.GetNumThreadsX());
4393
4394 IrArray::Index start_offset = [&] {
4395 llvm::Value* x_loc = thread_id_info.thread_id_x;
4396 llvm::Value* y_loc = thread_id_info.thread_id_y;
4397 if (!reduction_info.IsRowReduction()) {
4398 std::swap(x_loc, y_loc);
4399 }
4400 llvm::Value* start_offset_x =
4401 GetStartOffsetX(mapping_scheme, x_loc, index_ty, &b_);
4402 return tiling_kernel_info.tile_origin.AddOffsetToDim(y_loc, kDimY, &b_)
4403 .AddOffsetToDim(start_offset_x, kDimX, &b_);
4404 }();
4405
4406 absl::Span<llvm::AllocaInst* const> partial_result_addresses =
4407 reduction_info.GetPartialResultAddresses();
4408
4409 int num_partial_results = reduction_info.GetNumPartialResults();
4410
4411 // Emit an atomic operation that accumulates the partial reduction to the
4412 // output element. For row reduction, this is only for lane 0 due to the
4413 // if-statement emitted above.
4414 //
4415 // `i` is the compacted index for contiguous-dimension reductions. It's used
4416 // for accessing `reduction_info` and `reducers`, which are also compacted.
4417 int i = -1;
4418 for (int index : instr_index_group) {
4419 mlir::Operation* reduce_hlo =
4420 GetReduceFromUnnestedMlir(unnested_hlo, index);
4421 if (!IsReductionFromOrToContiguousDimensions(reduce_hlo)) {
4422 continue;
4423 }
4424 i++;
4425 auto operand_shape = TypeToShape(reduce_hlo->getOperand(0).getType());
4426 Shape reduction_kept_element_shape = ShapeUtil::FilterDimensions(
4427 [&](int64 dim) {
4428 return !absl::c_linear_search(
4429 reduce_hlo->getAttrOfType<mlir::DenseIntElementsAttr>(
4430 "dimensions"),
4431 dim);
4432 },
4433 operand_shape);
4434 for (int j = 0; j < num_partial_results; ++j) {
4435 llvm::Value* untransposed_output_linear_address =
4436 GetUntransposedOutputLinearAddress(
4437 &b_, start_offset.AddOffsetToDim(constant(j), kDimX, &b_),
4438 reduction_info);
4439
4440 // A reduction is allowed to transpose its output. For example, suppose
4441 // we are reducing the second dimension of f32[10,20,30]{3,2,1}. We are
4442 // allowed to produce as output either f32[10,30]{1,0} (no transpose) or
4443 // f32[10,30]{0,1} (transposing the two output dims).
4444 //
4445 // At this point in the function we have a "partial sum" of input elements
4446 // (stored in partial_result_addresses), and we need to accumulate it into
4447 // the correct output element.
4448 auto output_array = result_ir_arrays[index];
4449 IrArray::Index element_index(
4450 /*linear=*/untransposed_output_linear_address,
4451 reduction_kept_element_shape, &b_);
4452 IrArray::Index output_index(element_index.multidim(),
4453 output_array.GetShape(),
4454 element_index.GetType());
4455 llvm::Value* output_address = output_array.EmitArrayElementAddress(
4456 output_index, &b_, "output_element_address");
4457 llvm::Value* current_output = b_.CreateInBoundsGEP(
4458 partial_result_addresses[i], {constant(j)}, "current_output");
4459
4460 llvm::GlobalVariable* shared_cache = reduction_info.GetSharedCache()[i];
4461
4462 // __shared__ memory uses a different address space, so we cast it to
4463 // global address space before writing or reading.
4464 auto shared_to_global = [&](llvm::Value* input, llvm::Twine name = "") {
4465 return b_.CreateAddrSpaceCast(
4466 input,
4467 llvm::PointerType::get(input->getType()->getPointerElementType(),
4468 /*AddressSpace=*/0),
4469 name);
4470 };
4471
4472 auto is_zero = [&](llvm::Value* value) {
4473 return b_.CreateICmpEQ(value, constant(0));
4474 };
4475
4476 KernelSupportLibrary ksl(&b_);
4477 llvm::Type* element_type =
4478 partial_result_addresses[i]->getType()->getElementType();
4479 if (reduction_info.IsRowReduction()) {
4480 EmitFullWarpShuffleDownLoopForReduce(
4481 reducers[i], element_type, current_output,
4482 mapping_scheme.GetThreadsPerBlock());
4483 llvm::Value* warp_id =
4484 b_.CreateUDiv(thread_id_info.thread_id_x, constant(kWarpSize));
4485 ksl.If("intra_warp_reduce_write", is_zero(thread_id_info.lane_id), [&] {
4486 llvm::Value* shmem_output_addr =
4487 shared_to_global(b_.CreateInBoundsGEP(
4488 shared_cache, {b_.getInt32(0), constant(j), warp_id}));
4489 b_.CreateStore(b_.CreateLoad(current_output), shmem_output_addr);
4490 });
4491
4492 EmitSyncThreads();
4493 ksl.If("inter_warp_reduce", is_zero(warp_id), [&] {
4494 llvm::Value* block_accum_addr = shared_to_global(b_.CreateInBoundsGEP(
4495 shared_cache,
4496 {b_.getInt32(0), constant(j), thread_id_info.lane_id}));
4497 llvm::Value* initial_value = reduction_info.GetInitialValues()[i];
4498 llvm::Value* initial_value_addr =
4499 shared_to_global(llvm_ir::EmitAllocaAtFunctionEntry(
4500 element_type, "initial_value_addr", &b_));
4501 b_.CreateStore(initial_value, initial_value_addr);
4502
4503 llvm::Value* warp_exists = b_.CreateICmpULT(
4504 thread_id_info.thread_id_x,
4505 constant(mapping_scheme.GetNumThreadsX() / kWarpSize));
4506
4507 llvm::Value* selected_value = b_.CreateSelect(
4508 warp_exists, block_accum_addr, initial_value_addr);
4509
4510 EmitFullWarpShuffleDownLoopForReduce(
4511 reducers[i], element_type,
4512 /*block_accum_addr*/ selected_value,
4513 mapping_scheme.GetThreadsPerBlock());
4514 ksl.If("reduction_atomic_update", is_zero(thread_id_info.thread_id_x),
4515 [&] {
4516 TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
4517 *reducers[i], output_address, block_accum_addr));
4518 });
4519 });
4520
4521 } else {
4522 llvm::Value* shmem_output_addr = shared_to_global(
4523 b_.CreateInBoundsGEP(shared_cache, {b_.getInt32(0), constant(j),
4524 thread_id_info.thread_id_x,
4525 thread_id_info.thread_id_y}),
4526 "shmem_output_address");
4527 llvm::Value* current_output_value = b_.CreateLoad(current_output);
4528 b_.CreateStore(current_output_value, shmem_output_addr);
4529
4530 EmitSyncThreads();
4531
4532 // Get transposed element from shared memory.
4533 llvm::Value* shmem_transposed_addr =
4534 shared_to_global(b_.CreateInBoundsGEP(
4535 shared_cache,
4536 {b_.getInt32(0), constant(j), thread_id_info.thread_id_y,
4537 thread_id_info.thread_id_x},
4538 "shmem_transposed_addr"));
4539
4540 EmitFullWarpShuffleDownLoopForReduce(
4541 reducers[i], element_type, shmem_transposed_addr,
4542 mapping_scheme.GetThreadsPerBlock());
4543
4544 // Some threads in the block are completely outside of the bound of the
4545 // tensor, so they should not write any output at all.
4546 llvm::Value* has_output = b_.CreateAnd(
4547 b_.CreateICmpULT(
4548 GetStartOffsetX(mapping_scheme, thread_id_info.thread_id_y,
4549 index_ty, &b_),
4550 tiling_kernel_info.output_tile_bounds[kDimX]),
4551 b_.CreateICmpULT(thread_id_info.thread_id_x,
4552 tiling_kernel_info.output_tile_bounds[kDimY]));
4553
4554 ksl.If("reduction_atomic_update",
4555 b_.CreateAnd(has_output, is_zero(thread_id_info.lane_id)), [&] {
4556 TF_CHECK_OK(EmitAtomicOperationForNestedComputation(
4557 *reducers[i], output_address, shmem_transposed_addr));
4558 });
4559 }
4560 }
4561 }
4562 }
4563
4564 llvm::Value* IrEmitterUnnested::EmitBlockId() {
4565 return gpu::EmitCallToTargetIntrinsic(gpu::TargetIntrinsicID::kBlockIdx, {},
4566 {}, &b_);
4567 }
4568
4569 void IrEmitterUnnested::EmitPrintfWithThreadId(
4570 absl::string_view fmt, absl::Span<llvm::Value* const> arguments,
4571 absl::optional<int64> thread_id_filter,
4572 absl::optional<int64> block_id_filter) {
4573 llvm::Value* thread_id = EmitThreadId(1024, b_.getInt32Ty());
4574 llvm::Value* block_id = EmitBlockId();
4575 std::vector<llvm::Value*> updated_arguments = {thread_id, block_id};
4576 updated_arguments.insert(updated_arguments.end(), arguments.begin(),
4577 arguments.end());
4578 llvm::Value* constraint = b_.getTrue();
4579 if (thread_id_filter) {
4580 constraint = b_.CreateAnd(
4581 constraint, b_.CreateICmpEQ(thread_id, b_.getInt32(*thread_id_filter)));
4582 }
4583 if (block_id_filter) {
4584 constraint = b_.CreateAnd(
4585 constraint, b_.CreateICmpEQ(block_id, b_.getInt32(*block_id_filter)));
4586 }
4587 KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
4588 ksl.If(constraint, [&] {
4589 xla::gpu::EmitPrintf(absl::StrCat("[TID=%d,BID=%d] ", fmt, "\n"),
4590 updated_arguments, &b_);
4591 });
4592 }
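// Example usage (illustrative only, not part of the original source):
//   EmitPrintfWithThreadId("x_loc=%d", {x_loc},
//                          /*thread_id_filter=*/0, /*block_id_filter=*/0);
// emits a device-side printf of the form "[TID=%d,BID=%d] x_loc=%d\n" that
// fires only for thread 0 of block 0.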
4593
4594 void IrEmitterUnnested::EmitTileElementForReduction(
4595 mlir::Operation* unnested_hlo, const Shape& reduction_operand_shape,
4596 absl::Span<const int> instr_index_group, HloComputation* fused_computation,
4597 FusedIrEmitter* fused_emitter,
4598 absl::Span<const llvm_ir::IrArray> operand_ir_arrays,
4599 absl::Span<const llvm_ir::IrArray> result_ir_arrays,
4600 absl::Span<HloComputation* const> reducers,
4601 const llvm_ir::IrArray::Index& index,
4602 const ReductionCodegenInfo& reduction_info, int64 x_iter_num) {
4603 VLOG(10) << "Emit tile element for reduce " << MlirToString(unnested_hlo);
4604 int partial_result_index = reduction_info.IsRowReduction() ? 0 : x_iter_num;
4605
4606 InlinedVector<llvm_ir::ElementGenerator, 1> input_gens;
4607 std::vector<std::pair<llvm_ir::ElementGenerator, int>> extra_output_gens;
4608
4609   // Construct the ElementGenerator for each reduction and extra output in
4610   // the group of output instructions.
4611 if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo)) {
4612 for (int index : instr_index_group) {
4613 mlir::Operation* inst = GetReduceFromUnnestedMlir(unnested_hlo, index);
4614
4615 const HloInstruction* hlo = fused_computation->root_instruction();
4616 if (hlo->opcode() == HloOpcode::kTuple) {
4617 hlo = hlo->operand(index);
4618 }
4619 if (IsReductionFromOrToContiguousDimensions(inst)) {
4620 input_gens.push_back(*fused_emitter->GetGenerator(hlo->operand(0)));
4621 } else {
4622 extra_output_gens.emplace_back(*fused_emitter->GetGenerator(hlo),
4623 index);
4624 }
4625 }
4626 } else {
4627 input_gens.push_back([&](const IrArray::Index& index) {
4628 return operand_ir_arrays[0].EmitReadArrayElement(index, &b_);
4629 });
4630 }
4631
4632 IrArray::Index input_index =
4633 GetUnnormalizedIndex(index, reduction_operand_shape, &b_,
4634 reduction_info.GetKernelMappingScheme());
4635 // Clear the linear index field of the IrArray::Index to enable the use of
4636 // GetElementPointer with array types. This enables the vectorization of
4637 // the computation for different partial results. Use this index if
4638 // 'num_partial_results > 1'.
4639 int num_partial_results = reduction_info.GetNumPartialResults();
4640 auto index_without_linear = IrArray::Index(
4641 input_index.multidim(), reduction_operand_shape, input_index.GetType());
4642
4643 // Emit code to generate the input and perform the reduction computation for
4644 // each reduction instruction.
4645 for (int i = 0; i < reducers.size(); i++) {
4646 llvm::AllocaInst* input_address =
4647 reduction_info.GetReductionInputAddresses()[i];
4648 llvm::AllocaInst* partial_reduction_result_address =
4649 reduction_info.GetPartialResultAddresses()[i];
4650 llvm::Value* const input_ir_value =
4651 input_gens[i](num_partial_results > 1 ? index_without_linear
4652 : input_index)
4653 .ValueOrDie();
4654 Store(input_ir_value, input_address);
4655 llvm::Value* partial_result_address = InBoundsGEP(
4656 partial_reduction_result_address, {b_.getInt32(partial_result_index)});
4657 TF_CHECK_OK(EmitCallToNestedComputation(
4658 *reducers[i], {partial_result_address, input_address},
4659 partial_result_address));
4660 }
4661
4662 // Emit code to generate the output for the non-reduction instructions in the
4663 // fusion, if any.
4664 TF_CHECK_OK(EmitExtraOutputsForReduce(
4665 result_ir_arrays, input_index,
4666 /*use_linear_index=*/num_partial_results == 1, extra_output_gens));
4667 }
4668
4669 llvm::Value* IrEmitterUnnested::EmitThreadId(int64 threads_per_block,
4670 llvm::Type* index_ty) {
4671   // Calculate the (y, x) coordinates of a thread in the 2D view of the thread
4672   // block, defined by (num_thread_y, num_thread_x), from thread_id.
4673 llvm::CallInst* thread_id_raw = gpu::EmitCallToTargetIntrinsic(
4674 gpu::TargetIntrinsicID::kThreadIdx, {}, {}, &b_);
4675 llvm_ir::AddRangeMetadata(0, threads_per_block, thread_id_raw);
4676 return b_.CreateIntCast(thread_id_raw, index_ty,
4677 /*isSigned=*/true, "thread.id.x");
4678 }
4679
4680 IrEmitterUnnested::ThreadIdInfo IrEmitterUnnested::EmitThreadIdInfo(
4681 int64 threads_per_block, llvm::Type* index_ty, int64 num_threads_x) {
4682 auto constant = [&](uint64 c) -> llvm::Constant* {
4683 return llvm::ConstantInt::get(index_ty, c);
4684 };
4685 llvm::Value* thread_id = EmitThreadId(threads_per_block, index_ty);
4686 llvm::Value* num_threads_x_v = constant(num_threads_x);
4687 return {
4688 /*thread_id=*/thread_id,
4689 /*thread_id_x=*/b_.CreateURem(thread_id, num_threads_x_v, "thread_id.x"),
4690 /*thread_id_y=*/b_.CreateUDiv(thread_id, num_threads_x_v, "thread_id.y"),
4691 /*lane_id=*/b_.CreateURem(thread_id, constant(kWarpSize), "lane_id")};
4692 }
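// Illustration (hypothetical values): with threads_per_block = 256 and
// num_threads_x = 32, a raw thread id of 75 yields thread_id_x = 75 % 32 = 11,
// thread_id_y = 75 / 32 = 2, and lane_id = 75 % 32 = 11.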
4693
4694 IrEmitterUnnested::TilingKernelInfo IrEmitterUnnested::EmitTilingKernel(
4695 const KernelMappingScheme& mapping_scheme, llvm::Type* index_ty,
4696 const TileElementGenerator& tile_element_generator) {
4697 absl::Span<const int64> dims_in_elems = mapping_scheme.GetDimsInElems();
4698 std::vector<int64> dims_in_blocks = {
4699 CeilOfRatio(dims_in_elems[0], mapping_scheme.GetTileSizeZ()),
4700 CeilOfRatio(dims_in_elems[1], mapping_scheme.GetTileSizeY()),
4701 CeilOfRatio(dims_in_elems[2], mapping_scheme.GetTileSizeX())};
4702 auto constant = [&](uint64 c) -> llvm::Constant* {
4703 return llvm::ConstantInt::get(index_ty, c);
4704 };
4705
4706 IrEmitterUnnested::ThreadIdInfo thread_id_info =
4707 EmitThreadIdInfo(mapping_scheme.GetThreadsPerBlock(), index_ty,
4708 mapping_scheme.GetNumThreadsX());
4709
4710 KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
4711
4712 const IrArray::Index block_coords = [&] {
4713 llvm::Value* block_id = EmitBlockId();
4714 llvm_ir::AddRangeMetadata(0, mapping_scheme.GetNumberOfBlocks(),
4715 llvm::cast<llvm::Instruction>(block_id));
4716 llvm::Value* linear_block_id =
4717 b_.CreateIntCast(block_id, index_ty, /*isSigned=*/true, "block.id.x");
4718 IrArray::Index starting_block(linear_block_id,
4719 ShapeUtil::MakeShapeWithDescendingLayout(
4720 PRED /*arbitrary*/, dims_in_blocks),
4721 &b_);
4722
4723 std::vector<llvm::Value*> multidim = {
4724 b_.CreateMul(starting_block[0], constant(mapping_scheme.GetTileSizeZ()),
4725 "block_origin.z"),
4726 starting_block[1], starting_block[2]};
4727 return IrArray::Index(multidim, dims_in_blocks, index_ty);
4728 }();
4729
4730 std::array<llvm::Value*, 3> output_tile_bounds;
4731 for (int i = kDimY; i < kDimTot; ++i) {
4732 int64 tile_size_for_dim = mapping_scheme.GetTileSizeFor(i);
4733     // Only the last row or column may not have full size.
4734 llvm::Value* is_last =
4735 b_.CreateICmpEQ(block_coords[i], constant(dims_in_blocks[i] - 1));
4736 int64 partial_row =
4737 dims_in_elems[i] - (dims_in_blocks[i] - 1) * tile_size_for_dim;
4738 output_tile_bounds[i] =
4739 b_.CreateSelect(is_last, constant(partial_row),
4740 constant(tile_size_for_dim), "tile_bound");
4741 }
4742
4743 IrArray::Index tile_origin = [&] {
4744 std::vector<llvm::Value*> elem_multi_index = block_coords.multidim();
4745 llvm::Type* index_ty = block_coords.GetType();
4746 for (int i = kDimY; i < kDimTot; ++i) {
4747 elem_multi_index[i] = b_.CreateMul(
4748 block_coords[i],
4749 llvm::ConstantInt::get(index_ty, mapping_scheme.GetTileSizeFor(i)),
4750 "tile_origin." + std::to_string(i));
4751 }
4752 return IrArray::Index(elem_multi_index, mapping_scheme.GetDimsInElems(),
4753 index_ty);
4754 }();
4755
4756 auto emit_tile = [&](const IrArray::Index& tile) {
4757 tile_element_generator(thread_id_info, tile, "output",
4758 output_tile_bounds[1], output_tile_bounds[2], &ksl);
4759 };
4760
4761 if (mapping_scheme.GetTileSizeZ() == 1) {
4762 emit_tile(tile_origin);
4763 } else {
4764 llvm::Value* starting_tile_index_for_dim = tile_origin[kDimZ];
4765 llvm::Value* block_size_for_dim = constant(mapping_scheme.GetTileSizeZ());
4766 llvm::Value* block_id_for_dim =
4767 b_.CreateUDiv(starting_tile_index_for_dim, block_size_for_dim);
4768 llvm::Value* last_block_for_dim = constant(dims_in_blocks[kDimZ] - 1);
4769 llvm::Value* last_block_size_for_dim =
4770 constant(dims_in_elems[kDimZ] -
4771 (dims_in_blocks[kDimZ] - 1) * mapping_scheme.GetTileSizeZ());
4772
4773 llvm::Value* num_tiles_in_block =
4774 b_.CreateSelect(b_.CreateICmpEQ(last_block_for_dim, block_id_for_dim),
4775 last_block_size_for_dim, block_size_for_dim);
4776 ksl.For("loop_z",
4777 /*start=*/constant(0),
4778 /*end=*/num_tiles_in_block,
4779 /*step=*/1, [&](llvm::Value* block_dim_induction_var) {
4780 IrArray::Index tile_index = tile_origin.AddOffsetToDim(
4781 block_dim_induction_var, kDimZ, &b_);
4782 emit_tile(tile_index);
4783 });
4784 }
4785 return {output_tile_bounds, tile_origin};
4786 }
4787
4788 llvm::CallInst* IrEmitterUnnested::EmitSyncThreads() {
4789 return EmitCallToTargetIntrinsic(TargetIntrinsicID::kBarrierId, {}, {}, &b_);
4790 }
4791
4792 // Emits a kernel for the given hlo instruction using a tiled 0-2-1 transpose
4793 // algorithm to improve the memory access patterns for the input parameters
4794 // with a shape that is a 0-2-1 transpose of the output tensor shape. The caller
4795 // is responsible for making sure that it is safe to apply the shared memory
4796 // transpose on the input parameters.
4797 //
4798 //
4799 // For the purpose of tiling, the output tensors have a logical shape of three
4800 // components 0-2-1 while the relevant input parameters have a logical shape
4801 // of three components 0-1-2 in the order major to minor. The x- and y-
4802 // dimensions of the tensors are tiled in square tiles with an edge length
4803 // `kTileSize`. Each thread block of `kTileSize` x `kNumRows` threads
4804 // transposes one tile: each thread copies kTileSize/kNumRows elements from
4805 // the input to a shared memory tile, then the otherwise "regular HLO kernel"
4806 // reads from the shared memory instead of the original input.
4807 //
4808 // This is similar to the following CUDA algorithm in TensorFlow:
4809 // https://goo.gl/MStRV6.
4810 //
4811 // `kTileSize` should usually be same as warp size. We currently choose 32 for
4812 // `kTileSize` and 4 for `kNumRows`. The CUDA algorithm uses 8 for `kNumRows`.
4813 //
4814 // TODO(b/33320379): Here each block transposes 1 tile. It may be more
4815 // efficient to launch fewer blocks so each transposes many tiles.
4816 void IrEmitterUnnested::EmitHlo021Tile(
4817 mlir::Operation* op, Thunk* kernel_thunk, const MlirEmitterContext& context,
4818 absl::Span<const llvm_ir::IrArray> operand_arrays,
4819 absl::Span<const llvm_ir::IrArray> output_arrays,
4820 absl::Span<const int64> reduced_output_dims,
4821 absl::Span<const int64> tiled_param_ids) {
4822 constexpr int kNumRows = 4;
4823
4824 std::string name = mlir::GetNameFromLoc(op->getLoc());
4825
4826 KernelMappingScheme mapping_scheme(reduced_output_dims,
4827 /*tile_sizes=*/{1, kWarpSize, kWarpSize},
4828 /*num_threads_y=*/kNumRows,
4829 /*num_threads_x=*/kWarpSize,
4830 /*indexing_order=*/kLinearIndexingX,
4831 /*vector_size=*/1,
4832 /*is_row_contiguous=*/false);
4833 LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
4834 mapping_scheme.GetThreadsPerBlock());
4835
4836 llvm::Type* index_type =
4837 GetIndexTypeForKernelFromMlir(op, launch_dimensions.launch_bound(), &b_);
4838 std::vector<IrArray> param_arrays;
4839
4840 // For each tiled parameter, cast its input IrArray to the corresponding
4841 // reduced shape and keep the reduced shape live during IR emission.
4842 std::vector<IrArray> param_in_reduced_shape_arrays;
4843 std::vector<llvm::Value*> param_shmem_buffers(context.operand_shapes.size(),
4844 nullptr);
4845
4846 auto get_shared_memory_buffer = [&](llvm::Type* elem_ty,
4847 absl::string_view buffer_name) {
4848     // For Nvidia GPUs, the warp size is 32 threads and the shared memory is
4849     // organized into 32 banks. We usually use the warp size, or a multiple of
4850     // the warp size, as the tile size. This may cause all elements in the same
4851     // column of a tile to use the same memory bank, and therefore cause shared
4852     // memory bank conflicts. Adding 1 to the minor dimension of the shared
4853     // memory buffer can reduce such shared memory bank conflicts.
4854 llvm::Type* buffer_type = llvm::ArrayType::get(
4855 llvm::ArrayType::get(elem_ty, mapping_scheme.GetTileSizeX() + 1),
4856 mapping_scheme.GetTileSizeY());
4857 return llvm_ir::AllocateSharedMemoryTile(b_.GetInsertBlock()->getModule(),
4858 buffer_type, buffer_name);
4859 };
4860
4861 for (int64 id = 0; id < context.operand_shapes.size(); id++) {
4862 const Shape& param_shape = context.operand_shapes[id];
4863 param_arrays.push_back(operand_arrays[id]);
4864
4865 if (absl::c_linear_search(tiled_param_ids, id)) {
4866 param_shmem_buffers[id] = get_shared_memory_buffer(
4867 llvm_ir::PrimitiveTypeToIrType(param_shape.element_type(), module_),
4868 IrName(name, StrCat("tile", id)));
4869 VLOG(3) << "Added shmem buffer for parameter " << id << ": "
4870 << llvm_ir::DumpToString(*param_shmem_buffers[id]);
4871 Shape reduced_shape = ShapeUtil::MakeShapeWithDescendingLayout(
4872 param_shape.element_type(), Permute(reduced_output_dims, {0, 2, 1}));
4873 param_in_reduced_shape_arrays.push_back(
4874 param_arrays[id].CastToShape(reduced_shape, &b_));
4875 } else {
4876 param_in_reduced_shape_arrays.push_back(IrArray());
4877 }
4878 }
4879
4880 EmitElementFunction element_generator =
4881 [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
4882 llvm::Value* x_loc, int64 x_iter_num) {
4883 if (auto copy = mlir::dyn_cast<mlir::lmhlo::CopyOp>(op)) {
4884 CHECK_EQ(1, context.output_shapes.size());
4885 EmitTileElementForCopy(context.output_shapes[0], output_arrays[0],
4886 index, mapping_scheme, y_loc, x_loc,
4887 param_shmem_buffers);
4888 } else if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
4889 EmitTileElementForFusion(fusion, operand_arrays, output_arrays, index,
4890 mapping_scheme, y_loc, x_loc,
4891 param_shmem_buffers);
4892 } else {
4893 LOG(FATAL) << "Unexpected op: " << MlirToString(op);
4894 }
4895 };
4896
4897 TileElementGenerator tile_generator =
4898 [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
4899 const string& loop_name, llvm::Value* tile_height,
4900 llvm::Value* tile_width, KernelSupportLibrary* ksl) {
4901 // If shared memory transpose is needed, wait for all threads to reach
4902 // this point, lest we copy a value from tile to output before the other
4903 // thread copies it from input to tile. This is `__syncthreads` in CUDA.
4904 if (!tiled_param_ids.empty()) {
4905 // Calculate the input tile origin from the output tile origin.
4906 const IrArray::Index input_tile_origin(
4907 Permute(index.multidim(), {0, 2, 1}),
4908 Permute(index.dims(), {0, 2, 1}), index.GetType());
4909
4910 // Copy input parameter values to shared memory buffers:
4911 // tile[thread_id_y, thread_id_x] = input[index]
4912 // Note that tile_width and tile_height are flipped here because we
4913 // are reading a transposed tile.
4914 EmitTile(mapping_scheme, input_tile_origin, "input", ksl,
4915 thread_id_info, tile_width, tile_height,
4916 [&](const IrArray::Index& index, llvm::Value* y_loc,
4917 llvm::Value* x_loc, int64 /*x_iter_num*/) {
4918 for (int64 id : tiled_param_ids) {
4919 IrArray& input_in_logical_shape =
4920 param_in_reduced_shape_arrays.at(id);
4921
4922 llvm::Value* shmem_buffer = param_shmem_buffers.at(id);
4923 llvm::Value* zero =
4924 llvm::ConstantInt::get(index_type, 0);
4925 // TODO(jlebar): Add AA metadata to this store. Tile
4926 // buffers are global variables, so LLVM can't infer much
4927 // about it.
4928 auto value = input_in_logical_shape.EmitReadArrayElement(
4929 index, &b_, "input_element");
4930 auto addr = GEP(shmem_buffer, {zero, y_loc, x_loc});
4931 Store(value, addr);
4932 }
4933 });
4934
4935 // Wait for all threads to reach this point using `__syncthreads` in
4936 // CUDA.
4937 EmitSyncThreads();
4938 }
4939
4940 EmitTile(mapping_scheme, index, loop_name, ksl, thread_id_info,
4941 tile_height, tile_width, element_generator);
4942 bool block_contains_multi_tiles = mapping_scheme.GetTileSizeZ() > 1;
4943
4944 // If a tile block contains multiple tiles and shared memory buffers are
4945 // used, we need to wait for all threads to finish using the shared
4946 // memory buffer for the current tile before we move on to process the
4947 // next tile and overwrite the shared memory buffers.
4948 if (block_contains_multi_tiles && !tiled_param_ids.empty()) {
4949 EmitSyncThreads();
4950 }
4951 };
4952
4953 // For multioutput fusion, one thread needs to output a tuple
4954 // with pointers to all the individual outputs. We could do this
4955 // at any point in the kernel, but we do it at the beginning in
4956 // the hopes of reducing register pressure, since we touch
4957 // threadIdx.x and blockIdx.x at the beginning of the kernel
4958 // *anyway*.
4959 if (output_arrays.size() > 1) {
4960 KernelSupportLibrary{&b_}.If("emit_mof_tuple", IsBlock0Thread0(&b_), [&] {
4961 llvm_ir::EmitTuple(output_arrays.back(),
4962 output_arrays.subspan(0, output_arrays.size() - 1),
4963 &b_);
4964 });
4965 }
4966
4967 EmitTilingKernel(mapping_scheme, index_type, tile_generator);
4968 UpdateLaunchDimensions(launch_dimensions, kernel_thunk,
4969 ir_emitter_context_->llvm_module());
4970 }
4971
4972 namespace {
4973
4974 // A recursive function to inspect the users of a parameter to determine
4975 // whether it's safe for a parameter to participate in a shared-memory
4976 // transpose.
4977 //
4978 // Consider a fusion parameter P for which we might want to use a shmem
4979 // transpose. If we do, we use a GPU thread block to preload a tile of P with
4980 // indices [z, y..y+31, x..x+31] to compute an output tile with the same indices
4981 // cooperatively, where z, y, x are the indices for the normalized input/output
4982 // tensor (see the document for FindTranspose021 for the definition of
4983 // normalized tensor for 0-2-1 transpose). This shmem transpose implementation
4984 // requires that the computation of the output tile only read elements within
4985 // the preload tile. If this is not true, we can't use a shmem transpose for P.
4986 //
4987 // If the computation of output element [z, y, x] only requires the element of
4988 // P with the same indices, the shmem transpose implementation can be applied
4989 // to P safely. This is a sufficient but not necessary condition. We check all
4990 // the transitive users of P to see if we can find a user that may cause an
4991 // exception to the situation. If such a user is not found, we conclude that P
4992 // is safe for shmem transpose.
4993 //
4994 // This is trivially true for elementwise operations and some "data-movement"
4995 // ops like kTuple. However, it's not true for operations that can change the
4996 // dimensions of the inputs (e.g. pad, slice) or for the bitcast operation.
4997 // For example:
4998 //
4999 // fused_computation {
5000 // param_0 = f32[64,64]{1,0} parameter(0)
5001 // ROOT bitcast = f32[64,64]{0,1} bitcast(param_0)
5002 // }
5003 // The output element at logical address [0, 63] depends on the input element
5004 // at logical address [63, 0], which would not be within the shared-memory
5005 // block.
5006 //
5007 // TODO(bixia): In order to extend this for kInput fusion, that is, reduction
5008 // with transpose, we only need to end the use-chain checking at the input of
5009 // a reduce operation. In this case, the above description of "output" applies
5010 // to the result of such a use-chain, which provides the input to the reduce
5011 // operation.
5012 bool IsInstructionSafeForShmemTranspose(mlir::Operation* op) {
5013 if (mlir::isa<mlir::TensorStoreOp>(op)) {
5014 return true;
5015 }
5016
5017 HloOpcode opcode;
5018 if (mlir::isa<mlir::TensorLoadOp>(op)) {
5019 opcode = HloOpcode::kParameter;
5020 } else {
5021 opcode = *MhloToHloOpcode(op);
5022 }
5023 if (HloInstruction::IsOpElementwise(opcode)) {
5024 for (mlir::Value v : op->getResults()) {
5025 for (mlir::OpOperand use : v.getUsers()) {
5026 if (!IsInstructionSafeForShmemTranspose(use.getOwner())) {
5027 return false;
5028 }
5029 }
5030 }
5031 return true;
5032 }
5033
5034 switch (opcode) {
5035 // Non-elementwise instructions that don't cause the shmem transpose
5036 // to be unsafe, including the instructions that don't currently fuse.
5037 case HloOpcode::kGetDimensionSize:
5038 // The result of the operation doesn't rely on the content of the
5039 // tensor. As such, there is no need to further inspect its users.
5040 return true;
5041 case HloOpcode::kGetTupleElement:
5042 case HloOpcode::kMap:
5043 case HloOpcode::kParameter:
5044 case HloOpcode::kTuple:
5045 case HloOpcode::kTupleSelect:
5046 for (mlir::Value v : op->getResults()) {
5047 for (mlir::OpOperand use : v.getUsers()) {
5048 if (!IsInstructionSafeForShmemTranspose(use.getOwner())) {
5049 return false;
5050 }
5051 }
5052 }
5053 return true;
5054
5055 default:
5056 return false;
5057 }
5058 }
5059
5060 // Given a group of input parameters that are 0-2-1 transpose of the outputs of
5061 // a fusion kernel, returns the input parameters that are safe for the shared
5062 // memory transpose implementation.
5063 //
5064 // When a tile based shared memory transpose is used to implement an input with
5065 // 0-2-1 transpose, we preload a tile of the input elements
5066 // [z, y..y+31, x..x+31] to compute the output tile elements of the same
5067 // indices. Preloading the input tile this way is only safe when the computation
5068 // of the output tile elements does not need any input element outside the
5069 // preloaded tile. We inspect all the transitive users of the input parameter
5070 // up to the fusion root instruction to see if we can find any instruction
5071 // that can make preloading the input tile unsafe.
5072 std::vector<int64> FilterInputsForShmemTranspose(mlir::lmhlo::FusionOp fusion,
5073 std::vector<int64> input_ids) {
5074 std::vector<mlir::Value> params = ToStdVector(fusion.getFusionParameters());
5075
5076 std::vector<int64> filtered_input_ids;
5077 for (int64 input_id : input_ids) {
5078 mlir::Value input = params.at(input_id);
5079 if (IsInstructionSafeForShmemTranspose(input.getDefiningOp())) {
5080 filtered_input_ids.push_back(input_id);
5081 }
5082 }
5083 return filtered_input_ids;
5084 }
5085
5086 } // namespace
5087
5088 StatusOr<bool> IrEmitterUnnested::CheckAndEmitHloWithTile021(
5089 MlirEmitterInput input) {
5090 CHECK((mlir::isa<mlir::lmhlo::FusionOp, mlir::lmhlo::CopyOp>(input.op)));
5091
5092 MlirEmitterContext context;
5093 context.SetOperation(input.op);
5094
5095 // If the output_shape is reduced to 021 shape, find all the parameters of
5096 // the HLO that are in the corresponding 012 shape.
5097 std::vector<int64> params_012;
5098 optional<std::vector<int64>> reduced_dims_021;
5099 for (int64 operand_idx = 0; operand_idx < context.operand_shapes.size();
5100 ++operand_idx) {
5101 const Shape& operand_shape = context.operand_shapes[operand_idx];
5102 auto find_transpose_result =
5103 ShapeUtil::FindTranspose021(operand_shape, context.output_shapes[0]);
5104 if (!find_transpose_result.has_value()) {
5105 continue;
5106 }
5107 const std::vector<int64>& curr_reduced_dims_021 = *find_transpose_result;
5108 if (!reduced_dims_021.has_value()) {
5109 reduced_dims_021 = curr_reduced_dims_021;
5110 }
5111 if (!absl::c_equal(*reduced_dims_021, curr_reduced_dims_021)) {
5112 // There is more than one possible transpose. Instead of picking one
5113 // transpose, we simply give up here.
5114 return false;
5115 }
5116 params_012.push_back(operand_idx);
5117 }
5118
5119 if (!reduced_dims_021.has_value()) {
5120 return false;
5121 }
5122
5123 if ((*reduced_dims_021)[1] < kMinDimensionToTransposeTiled ||
5124 (*reduced_dims_021)[2] < kMinDimensionToTransposeTiled) {
5125 return false;
5126 }
5127
5128 if (auto fusion_op = mlir::dyn_cast<mlir::lmhlo::FusionOp>(input.op)) {
5129 params_012 = FilterInputsForShmemTranspose(fusion_op, params_012);
5130 if (params_012.empty()) {
5131 return false;
5132 }
5133 }
5134
5135 // Each of our shared memory tiles has 32*33 elements (so ~4kb, if the
5136 // elements are of size 4 bytes), and CUDA has an architectural limit of
5137 // 48kb shared memory per SM. (This is increased to 96kb in Volta, but we
5138 // don't use this, in part because it eats into our L1 cache space.)
5139 //
5140 // For correctness we need to ensure that we don't make more than 48kb worth
5141 // of shmem tiles per block. And for performance, we'd probably like to use
5142 // significantly less, so that we can fit more than one block at a time on a
5143 // gpu core.
5144 //
5145   // We say, without benchmarks, that we want at least 3 blocks per core
5146   // (kMinBlocksPerCore), which corresponds to roughly 3 shmem tiles per block
5147   // if the elements are 32 bits wide. We choose which params get the shmem
5148   // transpose treatment arbitrarily; it's not clear if there's a Right Choice.
5149 //
5150 // This is only sound if tiled transposes are the only place where we use
5151 // shared memory in fusions. If in the future other fusible ops use shared
5152 // memory, we'll have to adjust this heuristic.
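// For example, with f32 elements each tile occupies 32 * 33 * 4 = 4224 bytes,
// so the 48 KiB / 3 blocks ~= 16 KiB per-block budget admits at most three such
// tiles before the loop below truncates params_012.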
5153 constexpr int kMinBlocksPerCore = 3;
5154 constexpr int64 kShmemPerCore = 48 * 1024;
5155 int64 shmem_used = 0;
5156 for (int64 i = 0; i < params_012.size(); ++i) {
5157 const Shape& operand_shape = context.operand_shapes[params_012[i]];
5158 shmem_used +=
5159 32 * 33 *
5160 ShapeUtil::ByteSizeOfPrimitiveType(operand_shape.element_type());
5161
5162 if (kMinBlocksPerCore * shmem_used > kShmemPerCore) {
5163 // Erase this element and everything after it from params_012.
5164 params_012.resize(i);
5165 break;
5166 }
5167 }
5168
5169 if (params_012.empty()) {
5170 return false;
5171 }
5172
5173 std::vector<llvm_ir::IrArray> ir_arrays;
5174 TF_ASSIGN_OR_RETURN(std::unique_ptr<KernelThunk> kernel_thunk,
5175 BuildKernelThunkForMlir(input.op, input.thunk_info,
5176 input.extra_slice, &ir_arrays));
5177 EmitHlo021Tile(
5178 input.op, kernel_thunk.get(), context,
5179 absl::MakeSpan(ir_arrays).subspan(0, context.operand_shapes.size()),
5180 absl::MakeSpan(ir_arrays).subspan(context.operand_shapes.size()),
5181 *reduced_dims_021, params_012);
5182 AddThunkToThunkSequence(std::move(kernel_thunk));
5183 return true;
5184 }
5185
5186 namespace {
5187
5188 // Returns true if all the transitive users of `value`, up to the users in
5189 // `use_chain_endings`, are elementwise operations.
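// For example, for a use chain `param -> add -> exp -> reduce(root)` where the
// reduce is in `use_chain_endings`, this returns true; if the chain instead
// passes through a non-elementwise op (e.g. a gather) that is not an ending,
// it returns false.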
5190 bool AreUsersElementwise(
5191 mlir::Value value,
5192 const absl::flat_hash_set<mlir::Operation*>& use_chain_endings) {
5193 return absl::c_all_of(value.getUsers(), [&](mlir::OpOperand use) {
5194 mlir::Operation* user = use.getOwner();
5195 CHECK_EQ(1, user->getNumResults());
5196 return use_chain_endings.count(user) ||
5197 (HloInstruction::IsOpElementwise(*MhloToHloOpcode(user)) &&
5198 AreUsersElementwise(user->getResult(0), use_chain_endings));
5199 });
5200 }
5201
5202 // Returns the number of fusion inputs that have the same dimensions as the
5203 // given shape and are involved in only elementwise operations.
5204 int64 NumInputsInvolveInOnlyElementwiseOps(
5205 mlir::lmhlo::FusionOp fusion, const Shape& op_shape,
5206 const absl::flat_hash_set<mlir::Operation*>& use_chain_endings) {
5207 return absl::c_count_if(
5208 fusion.getFusionParameters(), [&](mlir::Value parameter) {
5209 Shape parameter_shape = TypeToShape(parameter.getType());
5210 return ShapeUtil::SameDimensions(op_shape, parameter_shape) &&
5211 AreUsersElementwise(parameter, use_chain_endings);
5212 });
5213 }
5214
5215 // Returns the number of fusion inputs that have more elements than the given
5216 // shape.
5217 int64 NumInputsWithMoreElementsThan(mlir::lmhlo::FusionOp fusion,
5218 const Shape& shape) {
5219 int64 num_elements = ShapeUtil::ElementsIn(shape);
5220 return absl::c_count_if(
5221 fusion.getFusionParameters(), [&](mlir::Value parameter) {
5222 Shape parameter_shape = TypeToShape(parameter.getType());
5223 return ShapeUtil::ElementsIn(parameter_shape) > num_elements;
5224 });
5225 }
5226
5227 // The benefit of unrolling a kInput fusion that is a column reduction comes
5228 // from the vectorization of non-reduction fusion outputs and fusion inputs.
5229 // On the other hand, unrolling can also introduce factors that can cause
5230 // the kernel to run slower. This routine uses a simple heuristic to estimate
5231 // the benefit as well as the overhead of unrolling in order to decide whether
5232 // unrolling is beneficial for the given kInput fusion.
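// A rough worked example, assuming the IsPowerOfTwo check below passes: a
// fusion whose single output is a column reduction (one non-vectorizable
// atomic write) fed by two inputs of the reduce-input shape that are used only
// elementwise (two vectorizable reads) gives can_be_vectorized = 2 >=
// cannot_be_vectorized = 1, so unrolling is considered beneficial.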
5233 bool IsUnrollingColumnReductionBeneficial(mlir::Operation* unnested_hlo,
5234 const Shape& input_shape,
5235 int64 num_kept_minor) {
5236 // TODO(b/122468062): Need to investigate further to see whether we can
5237 // remove the constraint on IsPowerOfTwo.
5238 if (!IsPowerOfTwo(static_cast<uint64>(num_kept_minor))) {
5239 return false;
5240 }
5241
5242 if (IsReductionFromOrToContiguousDimensions(unnested_hlo)) {
5243 return true;
5244 }
5245
5246 auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(unnested_hlo);
5247 int64 can_be_vectorized = 0;
5248 int64 cannot_be_vectorized = 0;
5249 auto fusion_results = ToStdVector(fusion.getFusionResults());
5250 absl::flat_hash_set<mlir::Operation*> use_chain_endings;
5251 if (fusion_results.size() == 1) {
5252 if (IsReductionFromOrToContiguousDimensions(
5253 fusion_results[0].getDefiningOp())) {
5254 use_chain_endings.insert(fusion_results[0].getDefiningOp());
5255 // Atomic.add of the reduction result can't be vectorized.
5256 cannot_be_vectorized++;
5257 }
5258 } else {
5259 for (mlir::Value result : fusion_results) {
5260 if (IsReductionFromOrToContiguousDimensions(result.getDefiningOp())) {
5261 // Atomic.add of the reduction result can't be vectorized.
5262 cannot_be_vectorized++;
5263 } else {
5264 // Write of the non-reduction result can be vectorized.
5265 can_be_vectorized++;
5266 }
5267 use_chain_endings.insert(result.getDefiningOp());
5268 }
5269 }
5270 // Fusion inputs that have the same dimensions as the reduce input and are
5271 // involved in only elementwise operations can be vectorized.
5272 can_be_vectorized += NumInputsInvolveInOnlyElementwiseOps(fusion, input_shape,
5273 use_chain_endings);
5274 // Fusion inputs with more elements than the reduce op input must participate
5275 // in non-elementwise operations and we assume that they are not vectorizable
5276 // for the purpose of estimating the benefit of unrolling. If the kernel is
5277 // unrolled even with such an assumption, and the accesses to those inputs
5278 // turn out to be vectorizable, the compiler will still vectorize them.
5279 cannot_be_vectorized += NumInputsWithMoreElementsThan(fusion, input_shape);
5280 return can_be_vectorized >= cannot_be_vectorized;
5281 }
5282
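// Rounds v to the closest power of two, breaking ties towards the smaller
// value, e.g. NearestPowerOfTwo(7) == 8 and NearestPowerOfTwo(5) == 4.
// Returns 0 for negative v.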
5283 int64 NearestPowerOfTwo(int64 v) {
5284 if (v < 0) {
5285 return 0;
5286 }
5287 int64 upper = tensorflow::NextPowerOfTwo64(v);
5288 int64 lower = upper >> 1;
5289 return upper - v < v - lower ? upper : lower;
5290 }
5291
5292 } // namespace
5293
5294 ReductionCodegenInfo IrEmitterUnnested::ComputeReductionCodegenInfo(
5295 mlir::Operation* unnested_hlo, mlir::Operation* first_reduce) {
5296 Shape input_shape = TypeToShape(first_reduce->getOperand(0).getType());
5297 ReductionDimensions reduction_dimensions =
5298 GetReductionKindAndContiguousComponents(first_reduce);
5299 VLOG(10) << "is_row_reduction " << reduction_dimensions.is_row_reduction
5300 << " " << reduction_dimensions.dimensions[0] << " "
5301 << reduction_dimensions.dimensions[1] << " "
5302 << reduction_dimensions.dimensions[2];
5303 auto get_dtype_bits = [](mlir::Value i) {
5304 // TODO(timshen): may not be efficient.
5305 return primitive_util::BitWidth(TypeToShape(i.getType()).element_type());
5306 };
5307
5308 // For fusion with multiple inputs, use the smallest input dtype to
5309 // select the reduction_tiling.
5310 int smallest_input_dtype_bits = get_dtype_bits(first_reduce->getOperand(0));
5311
5312 for (mlir::Value operand : GetHloOperands(unnested_hlo)) {
5313 smallest_input_dtype_bits =
5314 std::min(get_dtype_bits(operand), smallest_input_dtype_bits);
5315 }
5316 std::array<int64, 3> reduction_tiling =
5317 GetReductionTiling(reduction_dimensions, smallest_input_dtype_bits,
5318 ir_emitter_context_->cuda_compute_capability());
5319
5320 int64 num_threads_y = reduction_dimensions.is_row_reduction ? 1 : kWarpSize;
5321 int64 num_threads_x = [&] {
5322 if (reduction_dimensions.is_row_reduction) {
5323 // Use 512 as default block size (threads per block) for row reductions.
5324 // For multi-output fusions, reduce the block size further to decrease
5325 // register pressure when multiple outputs are computed by each thread.
5326 int64 fan_out = 1;
5327 if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo)) {
5328 fan_out = fusion.getFusionResults().size();
5329 }
5330
5331 // 64 is the general advice for the smallest block size.
5332 // Moreover, XLA:GPU emitters need at least 32 threads at some places.
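// For example, a multi-output fusion with 7 outputs yields
// max_block_size = std::max(64, 512 / NearestPowerOfTwo(7)) = 64.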
5333 int64 max_block_size = std::max(64LL, 512LL / NearestPowerOfTwo(fan_out));
5334 return std::min(
5335 max_block_size,
5336 RoundUpToNearest(CeilOfRatio(reduction_dimensions.dimensions[2],
5337 reduction_tiling[2]),
5338 kWarpSize));
5339 }
5340 return kWarpSize;
5341 }();
5342
5343 bool tile_fit = reduction_dimensions.dimensions[kDimX] %
5344 (reduction_tiling[2] * num_threads_x) ==
5345 0;
5346
5347 int cc_major = 0;
5348 if (ir_emitter_context_->cuda_compute_capability()) {
5349 cc_major = ir_emitter_context_->cuda_compute_capability()->cc_major;
5350 }
5351
5352 int num_partial_results = 1;
5353 KernelMappingScheme::IndexingOrder indexing_order = [&]() {
5354 if (reduction_dimensions.is_row_reduction &&
5355 // On P100, only try to vectorize+coalesce memory accesses when the
5356 // tile size fits exactly and dtypes are <= 32 bits.
5357 ((cc_major == 6 && smallest_input_dtype_bits <= 32 && tile_fit) ||
5358 // On V100, only try to vectorize+coalesce memory accesses for
5359 // rows of even size. For odd row sizes, every other row
5360 // isn't aligned, so it can't be vectorized.
5361 (cc_major >= 7 && reduction_dimensions.dimensions[2] % 2 == 0))) {
5362 return kStridedLinearIndexingX;
5363 } else if (!reduction_dimensions.is_row_reduction &&
5364 IsUnrollingColumnReductionBeneficial(
5365 unnested_hlo, input_shape,
5366 reduction_dimensions.dimensions[2])) {
5367 num_partial_results = 2;
5368 reduction_tiling[2] *= num_partial_results;
5369 return kLinearIndexingX;
5370 } else {
5371 return kStridedIndexingX;
5372 }
5373 }();
5374
5375 int vector_size = 1;
5376 if (indexing_order == kStridedLinearIndexingX) {
5377 // Assuming XLA will perform the unrolling and LLVM will vectorize,
5378 // disable the unroll for the cases that LLVM doesn't vectorize.
5379 if (reduction_dimensions.dimensions[2] % 2 == 0 &&
5380 !MayPreventVectorization(unnested_hlo)) {
5381 vector_size = 2;
5382 } else {
5383 indexing_order = kStridedIndexingX;
5384 }
5385 }
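// For example, a row reduction over rows of 1024 elements on V100 (cc_major
// >= 7) with no vectorization-blocking ops selects kStridedLinearIndexingX
// above and ends up with vector_size = 2 here.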
5386 KernelMappingScheme mapping_scheme(
5387 reduction_dimensions.dimensions,
5388 {reduction_tiling[0], reduction_tiling[1] * num_threads_y,
5389 reduction_tiling[2] * num_threads_x},
5390 num_threads_y, num_threads_x, indexing_order, vector_size);
5391 return ReductionCodegenInfo(mapping_scheme, num_partial_results,
5392 reduction_dimensions.is_row_reduction);
5393 }
5394
5395 void IrEmitterUnnested::EmitIRForReduction(
5396 mlir::Operation* unnested_hlo, absl::Span<const int> instr_index_group,
5397 HloComputation* fused_computation, FusedIrEmitter* fused_emitter,
5398 absl::Span<const llvm_ir::IrArray> operand_ir_arrays,
5399 absl::Span<const llvm_ir::IrArray> result_ir_arrays,
5400 ReductionCodegenInfo* reduction_info, const Shape& input_shape) {
5401 std::vector<HloComputation*> reducers;
5402 for (auto index : instr_index_group) {
5403 auto reduce = GetReduceFromUnnestedMlir(unnested_hlo, index);
5404 if (!IsReductionFromOrToContiguousDimensions(reduce)) {
5405 continue;
5406 }
5407 if (auto unnested_reduce = mlir::dyn_cast<mlir::lmhlo::ReduceOp>(reduce)) {
5408 reducers.push_back(
5409 *GetOrCreateSubComputationFromRegion(&unnested_reduce.body(),
5410 /*is_fusion=*/false));
5411 } else if (auto nested_reduce =
5412 mlir::dyn_cast<mlir::mhlo::ReduceOp>(reduce)) {
5413 HloInstruction* root = fused_computation->root_instruction();
5414 if (root->opcode() == HloOpcode::kTuple) {
5415 root = root->mutable_operand(index);
5416 } else {
5417 CHECK_EQ(0, index);
5418 }
5419 reducers.push_back(root->to_apply());
5420 } else {
5421 LOG(FATAL) << "Unexpected reduce op: " << MlirToString(reduce);
5422 }
5423 }
5424 CHECK(!reducers.empty()) << "Expected at least one reduce instruction.";
5425
5426 const KernelMappingScheme& mapping_scheme =
5427 reduction_info->GetKernelMappingScheme();
5428 LaunchDimensions launch_dimensions(mapping_scheme.GetNumberOfBlocks(),
5429 mapping_scheme.GetThreadsPerBlock());
5430 llvm::Type* index_ty = GetIndexTypeForKernelFromMlir(
5431 unnested_hlo, launch_dimensions.launch_bound(), &b_);
5432 EmitPrologueForReduction(unnested_hlo, instr_index_group, fused_computation,
5433 fused_emitter, operand_ir_arrays, result_ir_arrays,
5434 reduction_info);
5435
5436 EmitElementFunction emit_reduction_tile =
5437 [&](const llvm_ir::IrArray::Index& index, llvm::Value* y_loc,
5438 llvm::Value* x_loc, int64 x_iter_num) {
5439 EmitTileElementForReduction(
5440 unnested_hlo, input_shape, instr_index_group, fused_computation,
5441 fused_emitter, operand_ir_arrays, result_ir_arrays, reducers, index,
5442 *reduction_info, x_iter_num);
5443 };
5444
5445 TilingKernelInfo tiling_kernel_info = EmitTilingKernel(
5446 mapping_scheme, index_ty,
5447 [&](const ThreadIdInfo& thread_id_info, const IrArray::Index& index,
5448 const string& loop_name, llvm::Value* tile_height,
5449 llvm::Value* tile_width, KernelSupportLibrary* ksl) {
5450 EmitTile(reduction_info->GetKernelMappingScheme(), index, loop_name,
5451 ksl, thread_id_info, tile_height, tile_width,
5452 emit_reduction_tile);
5453 });
5454 EmitEpilogueForReduction(index_ty, unnested_hlo, instr_index_group,
5455 result_ir_arrays, reducers, *reduction_info,
5456 tiling_kernel_info);
5457 }
5458
5459 namespace {
5460
5461 // Returns whether the `instr` is either a constant, a scalar, or a
5462 // broadcasted constant/scalar.
5463 bool IsBroadcastedConstantOrScalar(const HloInstruction& instr) {
5464 return instr.IsConstant() || ShapeUtil::IsScalar(instr.shape()) ||
5465 (HloOpcode::kBroadcast == instr.opcode() &&
5466 (instr.operand(0)->IsConstant() ||
5467 ShapeUtil::IsScalar(instr.operand(0)->shape())));
5468 }
5469
5470 // Divides `num_reduces` reduces into groups. Different groups will be executed
5471 // in parallel. Generally speaking, we'd like to run the reduce instructions
5472 // in parallel without incurring too much recomputation overhead. The current
5473 // heuristic is to place reduce instructions that share nothing or only
5474 // (broadcasted) scalars/constants into different groups; otherwise, they are
5475 // placed in the same group. Non-reduce instructions always go with the reduce
5476 // instructions into the same group so long as they share any predecessors.
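// For example, two reduce outputs whose only shared operand is a broadcasted
// scalar end up in different groups and can run in parallel, while reduce
// outputs that share a non-trivial intermediate value are placed in the same
// group to limit recomputation.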
5477 std::vector<std::vector<int>> DivideOutputInstructionsIntoGroups(
5478 HloComputation* fused_computation, int num_reduces) {
5479 CHECK_NE(0, num_reduces);
5480 if (num_reduces == 1) {
5481 return {{0}};
5482 }
5483
5484 std::vector<tensorflow::UnionFind<HloInstruction*>> disjoint_sets(
5485 num_reduces);
5486 for (size_t i = 0; i < num_reduces; ++i) {
5487 disjoint_sets[i].Get() =
5488 fused_computation->root_instruction()->mutable_operand(i);
5489 }
5490
5491 std::unique_ptr<HloReachabilityMap> reachability_map =
5492 HloReachabilityMap::Build(fused_computation);
5493 for (auto* instr : fused_computation->instructions()) {
5494 std::vector<int64> reached_output_ids;
5495 for (size_t oid = 0; oid < num_reduces; ++oid) {
5496 auto reduce = fused_computation->root_instruction()->mutable_operand(oid);
5497 if (HloOpcode::kReduce == reduce->opcode() &&
5498 (IsBroadcastedConstantOrScalar(*instr))) {
5499 // Do not group output reduce instructions through broadcasted
5500 // constants or scalars, as the recomputation should be acceptable.
5501 VLOG(3) << "Skip broadcasted constant or scalar " << instr->ToString();
5502 continue;
5503 }
5504 // Now group output instructions if they have common predecessors.
5505 if (reachability_map->IsReachable(instr, reduce)) {
5506 VLOG(3) << "Reaching " << reduce->ToString() << " from "
5507 << instr->ToString();
5508 reached_output_ids.push_back(oid);
5509 }
5510 }
5511 for (size_t j = 1; j < reached_output_ids.size(); ++j) {
5512 disjoint_sets[reached_output_ids[0]].Merge(
5513 &disjoint_sets[reached_output_ids[j]]);
5514 }
5515 }
5516 // Place output instructions in the same set into the same group.
5517 absl::flat_hash_map<HloInstruction*, std::vector<int>> groups;
5518 for (size_t oid = 0; oid < num_reduces; ++oid) {
5519 groups[disjoint_sets[oid].Get()].push_back(oid);
5520 }
5521
5522 std::vector<std::vector<int>> ret;
5523 absl::c_for_each(
5524 groups, [&](auto& iter) { ret.emplace_back(std::move(iter.second)); });
5525 return ret;
5526 }
5527
5528 } // namespace
5529
5530 Status IrEmitterUnnested::EmitReductionFromOrToContiguousDimensions(
5531 MlirEmitterInput mlir_input) {
5532 mlir::Operation* unnested_hlo = mlir_input.op;
5533 auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo);
5534
5535 int num_reduces = 1;
5536 if (fusion) {
5537 num_reduces = fusion.getFusionResults().size();
5538 }
5539
5540 bool returns_tuple = num_reduces > 1;
5541 VLOG(10) << "Emitting reduction to vector " << MlirToString(unnested_hlo);
5542
5543 // Build an initializer thunk to initialize each reduction output.
5544 std::vector<std::unique_ptr<Thunk>> thunks;
5545 for (int i = 0; i < num_reduces; ++i) {
5546 mlir::Operation* output_instruction =
5547 GetReduceFromUnnestedMlir(unnested_hlo, i);
5548 if (!IsReductionFromOrToContiguousDimensions(output_instruction)) {
5549 continue;
5550 }
5551
5552 if (fusion) {
5553 TF_ASSIGN_OR_RETURN(std::unique_ptr<Thunk> initializer_thunk,
5554 BuildFusedInitializerThunkForMlir(fusion, i));
5555 thunks.push_back(std::move(initializer_thunk));
5556 } else {
5557 auto reduce = mlir::cast<mlir::lmhlo::ReduceOp>(output_instruction);
5558
5559 TF_RET_CHECK(!returns_tuple);
5560 TF_ASSIGN_OR_RETURN(
5561 std::unique_ptr<Thunk> initializer_thunk,
5562 BuildInitializerThunkForMlir(reduce, reduce.init_values()[0],
5563 reduce.out()[0]));
5564 thunks.push_back(std::move(initializer_thunk));
5565 }
5566 }
5567
5568 // Build a kernel thunk to compute all the outputs.
5569 mlir::Operation* first_reduce = nullptr;
5570 for (int i = 0; i < num_reduces; ++i) {
5571 if (IsReductionFromOrToContiguousDimensions(
5572 GetReduceFromUnnestedMlir(unnested_hlo, i))) {
5573 first_reduce = GetReduceFromUnnestedMlir(unnested_hlo, i);
5574 break;
5575 }
5576 }
5577 CHECK(first_reduce) << MlirToString(unnested_hlo);
5578 if (num_reduces > 1) {
5579 for (int i = 0; i < num_reduces; i++) {
5580 auto candidate = mlir::dyn_cast<mlir::mhlo::ReduceOp>(
5581 GetReduceFromUnnestedMlir(unnested_hlo, i));
5582 if (candidate &&
5583 !IsFusedReductionOutputConsistent(
5584 candidate, mlir::cast<mlir::mhlo::ReduceOp>(first_reduce))) {
5585 return InternalError("Inconsistent reduction fusion outputs");
5586 }
5587 }
5588 }
5589 Shape input_shape = TypeToShape(first_reduce->getOperand(0).getType());
5590 // The layout of a reduction input is either set by LayoutAssignment for
5591 // unnested kReduce or by InstructionFusion for fused kReduce.
5592 CHECK(input_shape.has_layout()) << "LayoutAssignment or InstructionFusion "
5593 "doesn't set the input layout of "
5594 << MlirToString(first_reduce);
5595
5596 std::vector<llvm_ir::IrArray> ir_arrays;
5597 TF_ASSIGN_OR_RETURN(std::unique_ptr<KernelThunk> kernel_thunk,
5598 BuildKernelThunkForMlir(unnested_hlo, Thunk::ThunkInfo(),
5599 {}, &ir_arrays));
5600
5601 HloComputation* fused_computation = nullptr;
5602 if (fusion) {
5603 TF_ASSIGN_OR_RETURN(fused_computation, GetOrCreateSubComputationFromRegion(
5604 &fusion.region(),
5605 /*is_fusion=*/true));
5606 }
5607
5608 // Group output instructions. Each group will be executed in parallel.
5609 std::vector<std::vector<int>> instr_index_groups =
5610 DivideOutputInstructionsIntoGroups(fused_computation, num_reduces);
5611
5612 VLOG(2) << StrCat("Generate in ", instr_index_groups.size(), " groups for ",
5613 MlirToString(unnested_hlo));
5614
5615 absl::optional<GpuElementalIrEmitter> elemental_emitter;
5616 absl::optional<FusedIrEmitter> optional_fused_emitter;
5617 FusedIrEmitter* fused_emitter = nullptr;
5618
5619 absl::Span<const llvm_ir::IrArray> operand_ir_arrays;
5620 absl::Span<const llvm_ir::IrArray> result_ir_arrays;
5621 if (fusion) {
5622 elemental_emitter.emplace(hlo_module_config_,
5623 ir_emitter_context_->llvm_module(), &b_,
5624 GetNestedComputer());
5625 optional_fused_emitter.emplace(&*elemental_emitter);
5626 fused_emitter = &*optional_fused_emitter;
5627
5628 CHECK_LT(fused_computation->num_parameters(), ir_arrays.size());
5629 for (int i = 0; i < fused_computation->num_parameters(); i++) {
5630 auto ir_array = ir_arrays[i];
5631 auto fused_operand = fused_computation->parameter_instruction(i);
5632 fused_emitter->BindGenerator(
5633 fused_operand, [this, ir_array,
5634 fused_operand](const llvm_ir::IrArray::Index& index) {
5635 return ir_array.EmitReadArrayElement(index, &b_,
5636 fused_operand->name());
5637 });
5638 }
5639 result_ir_arrays = absl::MakeSpan(ir_arrays).subspan(
5640 fused_computation->num_parameters(), num_reduces);
5641 } else {
5642 CHECK_EQ(3, ir_arrays.size());
5643 operand_ir_arrays = absl::MakeSpan(ir_arrays).subspan(0, 2);
5644 result_ir_arrays = absl::MakeSpan(ir_arrays).subspan(2);
5645 }
5646
5647 KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
5648 for (size_t i = 0; i < instr_index_groups.size(); ++i) {
5649 // Create a new ReductionCodegenInfo instance as it contains state for
5650 // code generation per reduction group. For now, let's always use the very
5651 // first reduce as representative to construct ReductionCodegenInfo, since
5652 // all the reductions are required to have the same shape and layout as
5653 // verified by `IsFusedReductionOutputConsistent()`. We can loosen the
5654 // constraint later when the need arises.
5655 ReductionCodegenInfo reduction_info =
5656 ComputeReductionCodegenInfo(unnested_hlo, first_reduce);
5657 auto emit_reduction_func = [&] {
5658 EmitIRForReduction(unnested_hlo, instr_index_groups[i], fused_computation,
5659 fused_emitter, operand_ir_arrays, result_ir_arrays,
5660 &reduction_info, input_shape);
5661 };
5662 // Use raw block_id_y to select the i-th parallel reduction to run. Using
5663 // block_id_y instead of block_id_x simplifies the index calculation
5664 // for reduction code generation as block_id_y is orthogonal to
5665 // the indices used within the reductions.
5666 llvm::CallInst* raw_block_id_y = gpu::EmitCallToTargetIntrinsic(
5667 gpu::TargetIntrinsicID::kBlockIdy, {}, {}, &b_);
5668 llvm_ir::AddRangeMetadata(0, instr_index_groups.size(),
5669 llvm::cast<llvm::Instruction>(raw_block_id_y));
5670 llvm::Value* guarding_cond =
5671 b_.CreateICmpEQ(raw_block_id_y, b_.getInt32(i));
5672 ksl.If(StrCat("reduce-group-", i), guarding_cond, emit_reduction_func);
5673 }
5674 ReductionCodegenInfo reduction_info =
5675 ComputeReductionCodegenInfo(unnested_hlo, first_reduce);
5676 const KernelMappingScheme& mapping_scheme =
5677 reduction_info.GetKernelMappingScheme();
5678 // block_y_count is set to instr_index_groups.size(), so that each reduction
5679 // group can be run in parallel by a different BlockIdy.
5680 LaunchDimensions launch_dimensions(
5681 {/*x=*/mapping_scheme.GetNumberOfBlocks(),
5682 /*y=*/static_cast<int64>(instr_index_groups.size()),
5683 /*z=*/1},
5684 {/*x=*/mapping_scheme.GetThreadsPerBlock(), /*y=*/1, /*z=*/1});
5685 VLOG(3) << "Launch dimensions of "
5686 << mlir::GetNameFromLoc(unnested_hlo->getLoc())
5687 << ": number of blocks: " << mapping_scheme.GetNumberOfBlocks()
5688 << " - threads per block: " << mapping_scheme.GetThreadsPerBlock();
5689 UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
5690 ir_emitter_context_->llvm_module());
5691
5692 thunks.push_back(std::move(kernel_thunk));
5693 std::unique_ptr<SequentialThunk> sequential_thunk =
5694 absl::make_unique<SequentialThunk>(mlir_input.thunk_info,
5695 std::move(thunks));
5696 AddThunkToThunkSequence(std::move(sequential_thunk));
5697
5698 return Status::OK();
5699 }
5700
5701 // Emits code for slices based on the structure below. An if statement with
5702 // a guarding condition is generated for each ROOT slice.
5703 //
5704 // Pseudo code:
5705 //
5706 // Compute values of slice input operands
5707 //
5708 // Compute guarding_cond0
5709 // if (guarding_cond0) {
5710 // Write to output of slice0
5711 // }
5712 //
5713 // Compute guarding_cond1
5714 // if (guarding_cond1) {
5715 // Write to output of slice1
5716 // }
5717 //
5718 Status IrEmitterUnnested::EmitElementForInputFusibleSlices(
5719 mlir::lmhlo::FusionOp fusion, absl::Span<const llvm_ir::IrArray> ir_arrays,
5720 const llvm_ir::IrArray::Index& index) {
5721 VLOG(10) << "Emitting slice input fusion for " << MlirToString(fusion);
5722
5723 TF_ASSIGN_OR_RETURN(const HloComputation* fused_computation,
5724 GetOrCreateSubComputationFromRegion(&fusion.region(),
5725 /*is_fusion=*/true));
5726
5727 HloInstruction* slice_or_tuple = fused_computation->root_instruction();
5728 auto slice_instructions = [&]() -> absl::Span<HloInstruction* const> {
5729 if (slice_or_tuple->opcode() == HloOpcode::kSlice) {
5730 return absl::Span<HloInstruction* const>(&slice_or_tuple, 1);
5731 }
5732 CHECK_EQ(slice_or_tuple->opcode(), HloOpcode::kTuple);
5733 return slice_or_tuple->operands();
5734 }();
5735
5736 // Emit input operand values of slices.
5737 std::vector<llvm::Value*> input_ir_values;
5738 GpuElementalIrEmitter elem_emitter(hlo_module_config_, module_, &b_,
5739 GetNestedComputer());
5740 FusedIrEmitter fused_emitter(&elem_emitter);
5741 for (int i = 0; i < fused_computation->num_parameters(); i++) {
5742 fused_emitter.BindGenerator(
5743 fused_computation->parameter_instruction(i),
5744 [this, &ir_arrays, i](llvm_ir::IrArray::Index index) {
5745 return ir_arrays[i].EmitReadArrayElement(index, &b_);
5746 });
5747 }
5748 for (const HloInstruction* slice : slice_instructions) {
5749 auto input_generator = *fused_emitter.GetGenerator(slice->operand(0));
5750 input_ir_values.push_back(input_generator(index).ValueOrDie());
5751 }
5752
5753 // Emit for slice_instructions.
5754 KernelSupportLibrary ksl(&b_, llvm_ir::UnrollMode::kDefaultUnroll);
5755 for (int64 i = 0; i < slice_instructions.size(); ++i) {
5756 HloInstruction* slice = slice_instructions[i];
5757
5758 // guarding_cond := index >= start && index < limit, for each dim.
5759 std::vector<llvm::Value*> index_within_ranges;
5760 for (size_t dim = 0; dim < slice->slice_starts().size(); ++dim) {
5761 CHECK_EQ(slice->slice_strides(dim), 1);
5762 auto larger_or_equal_than_start = b_.CreateICmpSGE(
5763 index.multidim()[dim],
5764 index.GetConstantWithIndexType(slice->slice_starts(dim)));
5765 llvm::Value* smaller_than_limit = b_.CreateICmpSLT(
5766 index.multidim()[dim],
5767 index.GetConstantWithIndexType(slice->slice_limits(dim)));
5768 llvm::Value* within_range =
5769 b_.CreateAnd(larger_or_equal_than_start, smaller_than_limit);
5770 index_within_ranges.push_back(within_range);
5771 }
5772 llvm::Value* guarding_cond = b_.CreateAnd(index_within_ranges);
5773
5774 auto emit_slice_elem_func = [&] {
5775 const std::vector<llvm::Value*>& src_multidim = index.multidim();
5776 std::vector<llvm::Value*> dst_multidim(src_multidim.size());
5777 for (size_t dim = 0; dim < src_multidim.size(); ++dim) {
5778 dst_multidim[dim] =
5779 Sub(src_multidim[dim],
5780 index.GetConstantWithIndexType(slice->slice_starts(dim)));
5781 }
5782 llvm_ir::IrArray src_ir_array =
5783 ir_arrays[fused_computation->num_parameters() + i];
5784 IrArray::Index slice_dst_index(dst_multidim, slice->shape(),
5785 index.GetType());
5786 src_ir_array.EmitWriteArrayElement(slice_dst_index, input_ir_values[i],
5787 &b_);
5788 };
5789
5790 ksl.If(StrCat("slice", i), guarding_cond, emit_slice_elem_func);
5791 }
5792 return Status::OK();
5793 }
5794
5795 Status IrEmitterUnnested::EmitInputFusibleNonStridedSlices(
5796 MlirEmitterInput mlir_input) {
5797 auto fusion = mlir::cast<mlir::lmhlo::FusionOp>(mlir_input.op);
5798
5799 constexpr int unroll_factor = 1;
5800
5801 std::vector<llvm_ir::IrArray> ir_arrays;
5802 TF_ASSIGN_OR_RETURN(
5803 auto kernel_thunk,
5804 BuildKernelThunkForMlir(fusion, mlir_input.thunk_info,
5805 mlir_input.extra_slice, &ir_arrays));
5806
5807 TF_ASSIGN_OR_RETURN(Shape element_shape,
5808 GetConsistentInputShapeForRootSlices(fusion));
5809 LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
5810 element_shape, ir_emitter_context_->gpu_device_info(), unroll_factor);
5811 UpdateLaunchDimensions(launch_dimensions, kernel_thunk.get(),
5812 ir_emitter_context_->llvm_module());
5813
5814 Status emit_status =
5815 ParallelLoopEmitter(
5816 [&](const llvm_ir::IrArray::Index index) -> Status {
5817 return EmitElementForInputFusibleSlices(fusion, ir_arrays, index);
5818 },
5819 element_shape, launch_dimensions, &b_)
5820 .EmitLoop(IrName(mlir::GetNameFromLoc(fusion.getLoc())),
5821 GetIndexTypeForKernelFromMlir(
5822 fusion, launch_dimensions.launch_bound(), &b_));
5823
5824 thunk_sequence_.emplace_back(std::move(kernel_thunk));
5825
5826 return emit_status;
5827 }
5828
5829 Thunk::ThunkInfo IrEmitterUnnested::GetThunkInfo(
5830 const HloInstruction* hlo) const {
5831 auto info = ThunkEmitter::EmissionContext::GetThunkInfo(hlo);
5832 if (const auto* index_map = ir_emitter_context_->profile_index_map()) {
5833 info.profile_index.emplace(
5834 static_cast<int64>(index_map->GetProfileIndexFor(*hlo)));
5835 }
5836 return info;
5837 }
5838
5839 Status IrEmitterUnnested::EmitOp(MlirEmitterInput mlir_input) {
5840 if (mlir::isa<mlir::lmhlo::SortOp>(mlir_input.op)) {
5841 return EmitSortFromMlir(mlir_input);
5842 }
5843 if (mlir::isa<mlir::lmhlo::CollectivePermuteOp>(mlir_input.op)) {
5844 return EmitCollectivePermuteFromMlir(mlir_input);
5845 }
5846 LOG(FATAL)
5847 << "This function is for test only, and the op is not implemented: "
5848 << MlirToString(mlir_input.op);
5849 }
5850
5851 void MlirEmitterContext::SetOperation(mlir::Operation* op) {
5852 this->name = mlir::GetNameFromLoc(op->getLoc());
5853
5854 auto operands = GetHloOperands(op);
5855 auto outputs = GetHloOutputs(op);
5856 for (auto operand : operands) {
5857 operand_shapes.push_back(TypeToShape(operand.getType()));
5858 }
5859 for (auto output : outputs) {
5860 output_shapes.push_back(TypeToShape(output.getType()));
5861 }
5862 }
5863
5864 Status IrEmitterUnnested::HandleBitcast(HloInstruction* bitcast) {
5865 TF_ASSIGN_OR_RETURN(auto input, GetMlirEmitterInput(bitcast));
5866 DCHECK_EQ(nullptr, input.op);
5867 return Status::OK();
5868 }
5869
5870 } // namespace gpu
5871 } // namespace xla
5872