/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"

#include <algorithm>
#include <array>
#include <vector>

#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Module.h"
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/stream_executor/device_description.h"

namespace xla {
namespace gpu {

namespace {

// Returns whether the given shape has rank 2, excluding the batch dimensions.
bool IsRank2(const Shape& shape, int64 batch_dimensions_size) {
  return shape.rank() == batch_dimensions_size + 2;
}

// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
                        const Shape& output_shape,
                        int64 batch_dimensions_size) {
  // The inputs and the output must
  // 1) be matrices with no padding and a non-zero number of elements,
  // 2) have an allowed element type.
  PrimitiveType output_primitive_type = output_shape.element_type();
  bool type_is_allowed =
      (output_primitive_type == F16 || output_primitive_type == F32 ||
       output_primitive_type == F64 || output_primitive_type == C64 ||
       output_primitive_type == C128);
  return type_is_allowed && IsRank2(lhs_shape, batch_dimensions_size) &&
         IsRank2(rhs_shape, batch_dimensions_size) &&
         IsRank2(output_shape, batch_dimensions_size) &&
         !ShapeUtil::IsZeroElementArray(lhs_shape) &&
         !ShapeUtil::IsZeroElementArray(rhs_shape);
}

// Given a shape and a group of contiguous dimensions in the shape, returns
// a tuple of three values (major, middle, minor), where major is the size of
// the dimensions more major than the given dimensions, minor is the size of
// dimensions more minor than the given dimensions, and middle is the size of
// the given dimensions.
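// For example, given a shape with dimensions {A, B, C, D}, a row-major layout
// (minor_to_major = {3, 2, 1, 0}), and dims_middle = {1, 2}, this returns
// {A, B * C, D}.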
std::array<int64, 3> PartitionShapeByMiddleDimensions(
    const Shape& shape, absl::Span<const int64> dims_middle) {
  CHECK(LayoutUtil::AreDimensionsConsecutive(shape.layout(), dims_middle));
  std::array<int64, 3> values = {1, 1, 1};
  enum Segment { kMajor = 0, kMiddle = 1, kMinor = 2 };
  Segment cur_segment = kMinor;

  for (int64 cur_dim : LayoutUtil::MinorToMajor(shape)) {
    if (cur_segment != kMajor) {
      // Handle change of segments.
      bool cur_dim_in_middle = absl::c_linear_search(dims_middle, cur_dim);
      if (cur_segment == kMinor) {
        if (cur_dim_in_middle) {
          cur_segment = kMiddle;
        }
      } else if (cur_segment == kMiddle) {
        if (!cur_dim_in_middle) {
          cur_segment = kMajor;
        }
      }
    }
    values[cur_segment] *= shape.dimensions(cur_dim);
  }
  return values;
}

}  // namespace

bool IsMatrixMultiplication(const HloInstruction& dot) {
  if (dot.opcode() != HloOpcode::kDot) {
    return false;
  }
  const Shape& lhs_shape = dot.operand(0)->shape();
  const Shape& rhs_shape = dot.operand(1)->shape();
  const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();

  // If gemm can accept the operand shapes, use it rather than a custom
  // kernel.
  if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape(),
                         dim_numbers.lhs_batch_dimensions_size())) {
    // The sizes of the contracting dimensions of lhs and rhs must match.
    // Shape inference guarantees this invariant, so the check here guards
    // against programming errors only.
    CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
             rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
    return true;
  }
  return false;
}

bool IsCublasGemm(const HloInstruction& hlo) {
  return hlo.opcode() == HloOpcode::kCustomCall &&
         hlo.custom_call_target() == kGemmCallTarget;
}

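// Returns the reduction tiling for the given reduction as {tile_z, tile_y,
// tile_x}. Row reductions are tiled along x, with larger unroll factors for
// 8- and 16-bit element types on compute capability 6.0+; column reductions
// use a fixed {1, 128, 1} tiling.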
std::array<int64, 3> GetReductionTiling(
    const ReductionDimensions& reduction_dimensions,
    int smallest_input_dtype_bits,
    absl::optional<CudaComputeCapability> cuda_compute_capability) {
  if (reduction_dimensions.is_row_reduction) {
    int64 tile_z = std::min(reduction_dimensions.dimensions[0], int64{8});
    if (reduction_dimensions.dimensions[1] == 1) {
      CHECK_EQ(reduction_dimensions.dimensions[0], 1);
      return {tile_z, 1, 16};
    }
    if (reduction_dimensions.dimensions[2] % (kWarpSize * kWarpSize * 64) ==
        0) {
      return {tile_z, 1, 64};
    }
    int cc_major = 0;
    if (cuda_compute_capability) {
      cc_major = cuda_compute_capability->cc_major;
    }
    int unroll_x = 8;
    if (cc_major >= 6 && smallest_input_dtype_bits == 16) {
      unroll_x = 16;
    } else if (cc_major >= 6 && smallest_input_dtype_bits == 8) {
      unroll_x = 64;
    }
    return {tile_z, 1, unroll_x};
  }

  // Column reduction.
  return {1, 128, 1};
}

const char* const kCudnnBatchNormForwardInferenceCallTarget =
    "__cudnn$batchNormalizationForwardInference";
const char* const kCudnnBatchNormForwardTrainingCallTarget =
    "__cudnn$batchNormalizationForwardTraining";
const char* const kCudnnBatchNormBackwardCallTarget =
    "__cudnn$batchNormalizationBackward";

bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCudnnBatchNormForwardInferenceCallTarget ||
         target == kCudnnBatchNormForwardTrainingCallTarget ||
         target == kCudnnBatchNormBackwardCallTarget;
}

const char* const kGemmCallTarget = "__cublas$gemm";
const char* const kCudnnConvForwardCallTarget = "__cudnn$convForward";
const char* const kCudnnConvBackwardInputCallTarget =
    "__cudnn$convBackwardInput";
const char* const kCudnnConvBackwardFilterCallTarget =
    "__cudnn$convBackwardFilter";
const char* const kCudnnConvBiasActivationForwardCallTarget =
    "__cudnn$convBiasActivationForward";

bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCudnnConvForwardCallTarget ||
         target == kCudnnConvBackwardInputCallTarget ||
         target == kCudnnConvBackwardFilterCallTarget ||
         target == kCudnnConvBiasActivationForwardCallTarget;
}

const char* const kCusolverCholeskyCallTarget = "__cusolver$cholesky";

bool IsCustomCallToCusolver(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCusolverCholeskyCallTarget;
}

bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
  return IsCublasGemm(hlo) || IsCustomCallToDnnBatchNorm(hlo) ||
         IsCustomCallToDnnConvolution(hlo);
}

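// Classifies a reduction of `input_shape` over `dims_to_reduce` as a row or
// column reduction and computes its contiguous (depth, height, width)
// components. For example, reducing f32[100,64]{1,0} over dimension {1} is a
// row reduction with dimensions {1, 100, 64}.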
static ReductionDimensions GetReductionKindAndContiguousComponentsImpl(
    const Shape& input_shape, absl::Span<const int64> dims_to_reduce) {
  DimensionVector dims_to_keep;
  for (int64 dim = 0; dim < input_shape.rank(); ++dim) {
    if (!absl::c_linear_search(dims_to_reduce, dim)) {
      dims_to_keep.push_back(dim);
    }
  }

  if (dims_to_keep.empty()) {
    return {/*is_row_reduction=*/true,
            {1, 1, ShapeUtil::ElementsIn(input_shape)}};
  }

  if (LayoutUtil::AreDimensionsConsecutive(input_shape.layout(),
                                           dims_to_keep)) {
    std::array<int64, 3> shape_partition =
        PartitionShapeByMiddleDimensions(input_shape, dims_to_keep);
    if (shape_partition[1] == 1) {
      return {/*is_row_reduction=*/true,
              {1, 1, shape_partition[0] * shape_partition[2]}};
    }
    if (shape_partition[2] == 1) {
      return {/*is_row_reduction=*/false,
              {1, shape_partition[0], shape_partition[1]}};
    }
    return {/*is_row_reduction=*/true, shape_partition};
  }

  std::array<int64, 3> shape_partition =
      PartitionShapeByMiddleDimensions(input_shape, dims_to_reduce);

  if (shape_partition[2] == 1) {
    return {/*is_row_reduction=*/true,
            {1, shape_partition[0], shape_partition[1]}};
  }
  return {/*is_row_reduction=*/false, shape_partition};
}

bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) {
  if (HloOpcode::kReduce != reduce.opcode()) {
    return false;
  }

  // TODO(b/129698548): Remove this check after fixing the bug.
  if (reduce.shape().element_type() == C128) {
    return false;
  }

  const HloInstruction* input = reduce.operand(0);
  std::vector<int64> dims_to_keep;
  for (int64 dim = 0; dim < input->shape().dimensions().size(); ++dim) {
    if (!absl::c_linear_search(reduce.dimensions(), dim)) {
      dims_to_keep.push_back(dim);
    }
  }

  // We support fast codegen for three cases:
  // 1) Row reduction: (K, R)
  // 2) Column reduction: (K, R, K)
  // 3) "Batched" row reduction: (R, K, R)
  if (!LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
                                            dims_to_keep) &&
      !LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
                                            reduce.dimensions())) {
    return false;
  }

  ReductionDimensions reduction_dimensions =
      GetReductionKindAndContiguousComponents(reduce);

  if (reduction_dimensions.is_row_reduction) {
    // For row reduction, the tile block is 1 x tile_size_x, and we are reducing
    // along tile_size_x, which needs to be large enough to make the tiling
    // implementation efficient.
    return reduction_dimensions.dimensions[2] >= kWarpSize;
  }

  // For column reduction, the tile block is tile_size_y x tile_size_x, and we
  // are reducing along tile_size_y. Only tile_size_y needs to be
  // large enough to make the tiling implementation efficient.
  return reduction_dimensions.dimensions[1] >= kWarpSize;
}

bool IsReductionFromOrToContiguousDimensions(mlir::Operation* reduce) {
  if (!mlir::isa<mlir::lmhlo::ReduceOp>(reduce) &&
      !mlir::isa<mlir::mhlo::ReduceOp>(reduce)) {
    return false;
  }
  std::vector<mlir::Value> results = GetHloOutputs(reduce);
  CHECK_EQ(1, results.size());

  auto c128_type =
      mlir::ComplexType::get(mlir::FloatType::getF64(reduce->getContext()));

  // TODO(b/129698548): Remove this check after fixing the bug.
  if (results[0].getType().cast<mlir::ShapedType>().getElementType() ==
      c128_type) {
    return false;
  }

  mlir::Value input = reduce->getOperand(0);
  Shape operand_shape = TypeToShape(input.getType());
  if (auto tensor_type = input.getType().dyn_cast<mlir::TensorType>()) {
    if (auto attr = mlir::GetLayoutFromMlirHlo(input.getDefiningOp())) {
      std::vector<int64> minor_to_major;
      absl::c_transform(
          attr, std::back_inserter(minor_to_major),
          std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
      *operand_shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
    }
  }

  std::vector<int64> dimensions;
  {
    auto attr = reduce->getAttrOfType<mlir::DenseIntElementsAttr>("dimensions");
    CHECK(attr);
    absl::c_transform(
        attr, std::back_inserter(dimensions),
        std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
  }

  std::vector<int64> dims_to_keep;
  for (int64 dim = 0; dim < operand_shape.dimensions().size(); ++dim) {
    if (!absl::c_linear_search(dimensions, dim)) {
      dims_to_keep.push_back(dim);
    }
  }

  // We support fast codegen for three cases:
  // 1) Row reduction: (K, R)
  // 2) Column reduction: (K, R, K)
  // 3) "Batched" row reduction: (R, K, R)
  if (!LayoutUtil::AreDimensionsConsecutive(operand_shape.layout(),
                                            dims_to_keep) &&
      !LayoutUtil::AreDimensionsConsecutive(operand_shape.layout(),
                                            dimensions)) {
    return false;
  }

  ReductionDimensions reduction_dimensions =
      GetReductionKindAndContiguousComponentsImpl(operand_shape, dimensions);

  if (reduction_dimensions.is_row_reduction) {
    // For row reduction, the tile block is 1 x tile_size_x, and we are reducing
    // along tile_size_x, which needs to be large enough to make the tiling
    // implementation efficient.
    return reduction_dimensions.dimensions[2] >= kWarpSize;
  }

  // For column reduction, the tile block is tile_size_y x tile_size_x, and we
  // are reducing along tile_size_y. Only tile_size_y needs to be
  // large enough to make the tiling implementation efficient.
  return reduction_dimensions.dimensions[1] >= kWarpSize;
}

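// Returns true if `unnested_hlo` is a fusion whose results are all produced by
// mhlo::SliceOp and, if `verify_no_strides` is set, every slice has unit
// strides.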
bool IsInputFusibleSlices(mlir::Operation* unnested_hlo,
                          bool verify_no_strides) {
  auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo);
  if (!fusion) {
    return false;
  }

  auto is_non_strided = [](mlir::DenseIntElementsAttr strides) -> bool {
    return absl::c_all_of(
        strides, [](const llvm::APInt& stride) { return stride == 1; });
  };

  for (mlir::Value value : fusion.getFusionResults()) {
    auto slice =
        mlir::dyn_cast_or_null<mlir::mhlo::SliceOp>(value.getDefiningOp());
    if (!slice) {
      return false;
    }
    if (verify_no_strides && !is_non_strided(slice.strides())) {
      return false;
    }
  }
  return true;
}

ReductionDimensions GetReductionKindAndContiguousComponents(
    const HloInstruction& reduce) {
  return GetReductionKindAndContiguousComponentsImpl(reduce.operand(0)->shape(),
                                                     reduce.dimensions());
}

ReductionDimensions GetReductionKindAndContiguousComponents(
    mlir::Operation* reduce) {
  mlir::Value input = reduce->getOperand(0);
  Shape operand_shape = TypeToShape(input.getType());
  std::vector<int64> dimensions;
  {
    auto attr = reduce->getAttrOfType<mlir::DenseIntElementsAttr>("dimensions");
    CHECK(attr);
    absl::c_transform(
        attr, std::back_inserter(dimensions),
        std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
  }
  return GetReductionKindAndContiguousComponentsImpl(operand_shape, dimensions);
}

// This emits a device-side call to
// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see
// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
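//
// Illustrative usage (the value names are hypothetical):
//   EmitPrintf("block=%d value=%f\n", {block_id, some_f32_value}, &b);
// emits a vprintf call that prints both values from the device.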
llvm::Value* EmitPrintf(absl::string_view fmt,
                        absl::Span<llvm::Value* const> arguments,
                        llvm::IRBuilder<>* builder) {
  std::vector<llvm::Type*> argument_types;

  // Implicit promotion of variadic arguments [1] converts float to double,
  // and bool/char/short to int.
  // [1] https://en.cppreference.com/w/cpp/language/variadic_arguments
  auto requires_int32_promotion = [](llvm::Type* type) {
    return type->isIntegerTy(/*BitWidth=*/1) ||
           type->isIntegerTy(/*BitWidth=*/8) ||
           type->isIntegerTy(/*BitWidth=*/16);
  };
  auto requires_double_promotion = [](llvm::Type* type) {
    return type->isFloatingPointTy();
  };

  for (auto argument : arguments) {
    llvm::Type* type = argument->getType();
    if (requires_double_promotion(type)) {
      argument_types.push_back(builder->getDoubleTy());
    } else if (requires_int32_promotion(type)) {
      argument_types.push_back(builder->getInt32Ty());
    } else {
      argument_types.push_back(type);
    }
  }
  auto* arguments_type = llvm::StructType::create(argument_types);
  llvm::Value* arguments_ptr = builder->CreateAlloca(arguments_type);
  for (size_t i = 0; i < arguments.size(); ++i) {
    llvm::Value* value = arguments[i];
    llvm::Type* type = value->getType();
    if (requires_double_promotion(type)) {
      value = builder->CreateFPCast(value, builder->getDoubleTy());
    } else if (requires_int32_promotion(type)) {
      value = builder->CreateIntCast(value, builder->getInt32Ty(),
                                     /*isSigned=*/true);
    }
    builder->CreateStore(
        value, builder->CreateGEP(arguments_ptr, {builder->getInt64(0),
                                                  builder->getInt32(i)}));
  }
  llvm::Type* ptr_ty = builder->getInt8Ty()->getPointerTo();
  return builder->CreateCall(
      builder->GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
          "vprintf",
          llvm::FunctionType::get(builder->getInt32Ty(), {ptr_ty, ptr_ty},
                                  /*isVarArg=*/false)),
      {builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)),
       builder->CreatePointerCast(arguments_ptr, ptr_ty)});
}

// Helper function to emit a call to the AMDGPU shfl_down function.
llvm::Value* EmitAMDGPUShflDown(llvm::Value* value, llvm::Value* offset,
                                llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32);
  auto* i32_ty = b->getInt32Ty();
  llvm::FunctionCallee shfl_fn = module->getOrInsertFunction(
      llvm_ir::AsStringRef("__ockl_readuplane_i32"),
      llvm::FunctionType::get(/*Result=*/i32_ty, {i32_ty, i32_ty},
                              /*isVarArg=*/false));
  // The AMDGPU device function requires its first argument to be an i32.
  llvm::Value* result =
      b->CreateCall(shfl_fn, {b->CreateBitCast(value, i32_ty), offset});
  // The AMDGPU device function always returns an i32.
  return b->CreateBitCast(result, value->getType());
}

// Helper function to emit a call to the NVPTX shfl_down intrinsic.
llvm::Value* EmitNVPTXShflDown(llvm::Value* value, llvm::Value* offset,
                               llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  llvm::Intrinsic::ID llvm_intrinsic_id;
  CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32);
  if (value->getType()->isFloatTy()) {
    llvm_intrinsic_id = llvm::Intrinsic::nvvm_shfl_sync_down_f32;
  } else {
    llvm_intrinsic_id = llvm::Intrinsic::nvvm_shfl_sync_down_i32;
  }
  llvm::Function* intrinsic =
      llvm::Intrinsic::getDeclaration(module, llvm_intrinsic_id, {});
  return b->CreateCall(
      intrinsic, {b->getInt32(-1), value, offset, b->getInt32(kWarpSize - 1)});
}

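// Emits a full-warp shuffle-down of `value` by `offset`, dispatching to the
// NVPTX intrinsic or the AMDGPU device function based on the target triple.
// Values wider than 32 bits are split into 32-bit segments, shuffled segment
// by segment, and reassembled, since the underlying shfl operation only
// handles 32-bit values.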
llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
                                     llvm::IRBuilder<>* builder) {
  int bit_width = value->getType()->getPrimitiveSizeInBits();
  llvm::Module* module = builder->GetInsertBlock()->getModule();
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());

  // Special case for efficiency.
  if (value->getType()->isFloatTy() && bit_width == 32) {
    if (target_triple.isNVPTX()) {
      return EmitNVPTXShflDown(value, offset, builder);
    } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
      return EmitAMDGPUShflDown(value, offset, builder);
    } else {
      LOG(FATAL) << "Invalid triple " << target_triple.str();
    }
  }

  // We must split values wider than 32 bits as the "shfl" instruction operates
  // on 32-bit values.
  int num_segments = CeilOfRatio(bit_width, 32);
  llvm::Value* x = builder->CreateBitCast(
      builder->CreateZExt(
          builder->CreateBitCast(value, builder->getIntNTy(bit_width)),
          builder->getIntNTy(32 * num_segments)),
      llvm::VectorType::get(builder->getInt32Ty(), num_segments, false));
  for (int i = 0; i < num_segments; ++i) {
    llvm::Value* insert_val;
    if (target_triple.isNVPTX()) {
      insert_val = EmitNVPTXShflDown(builder->CreateExtractElement(x, i),
                                     offset, builder);
    } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
      insert_val = EmitAMDGPUShflDown(builder->CreateExtractElement(x, i),
                                      offset, builder);
    } else {
      LOG(FATAL) << "Invalid triple " << target_triple.str();
    }
    x = builder->CreateInsertElement(x, insert_val, i);
  }
  return builder->CreateBitCast(
      builder->CreateTrunc(
          builder->CreateBitCast(x, builder->getIntNTy(32 * num_segments)),
          builder->getIntNTy(bit_width)),
      value->getType());
}

StatusOr<CudnnConvKind> GetCudnnConvKind(
    const HloCustomCallInstruction* instr) {
  absl::string_view target = instr->custom_call_target();
  if (target == kCudnnConvForwardCallTarget) {
    return CudnnConvKind::kForward;
  }
  if (target == kCudnnConvBackwardInputCallTarget) {
    return CudnnConvKind::kBackwardInput;
  }
  if (target == kCudnnConvBackwardFilterCallTarget) {
    return CudnnConvKind::kBackwardFilter;
  }
  if (target == kCudnnConvBiasActivationForwardCallTarget) {
    return CudnnConvKind::kForwardActivation;
  }
  return InternalError("Unexpected call target: %s", target);
}

string CudnnConvKindToString(CudnnConvKind kind) {
  switch (kind) {
    case CudnnConvKind::kForward:
      return "forward";
    case CudnnConvKind::kBackwardFilter:
      return "backward_filter";
    case CudnnConvKind::kBackwardInput:
      return "backward_input";
    case CudnnConvKind::kForwardActivation:
      return "forward with activation";
  }
}

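// Returns an i1 value that is true only for thread 0 of block 0.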
llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) {
  llvm::Value* is_thread0 = b->CreateICmpEQ(
      b->getInt32(0),
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b));

  llvm::Value* is_block0 = b->CreateICmpEQ(
      b->getInt32(0),
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b));
  return b->CreateAnd(is_thread0, is_block0);
}

bool IsFusedReductionOutputConsistent(const HloInstruction* inst,
                                      const HloInstruction* first_reduce) {
  if (IsReductionFromOrToContiguousDimensions(*inst)) {
    // Shapes, layouts and dimensions must be the same for all reduces
    // inside of this fusion.
    // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
    return ShapeUtil::Equal(first_reduce->shape(), inst->shape()) &&
           ShapeUtil::Equal(first_reduce->operand(0)->shape(),
                            inst->operand(0)->shape()) &&
           ShapeUtil::Equal(first_reduce->operand(1)->shape(),
                            inst->operand(1)->shape()) &&
           first_reduce->dimensions() == inst->dimensions();
  }
  return ShapeUtil::CompatibleIgnoringElementType(
             first_reduce->operand(0)->shape(), inst->shape()) &&
         LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
                           inst->shape().layout());
}

bool IsFusedReductionOutputConsistent(mlir::mhlo::ReduceOp inst,
                                      mlir::mhlo::ReduceOp first_reduce) {
  CHECK_EQ(1, first_reduce.getNumResults());
  Shape first_reduce_operand_shape =
      TypeToShape(first_reduce.operands()[0].getType());
  CHECK_EQ(1, inst.getNumResults());
  auto inst_shape = TypeToShape(inst.getResult(0).getType());

  if (IsReductionFromOrToContiguousDimensions(inst)) {
    auto first_reduce_shape = TypeToShape(first_reduce.getResult(0).getType());
    auto first_reduce_init_shape =
        TypeToShape(first_reduce.init_values()[0].getType());

    auto inst_operand_shape = TypeToShape(inst.operands()[0].getType());
    auto inst_init_shape = TypeToShape(inst.init_values()[0].getType());

    // Shapes, layouts and dimensions must be the same for all reduces
    // inside of this fusion.
    // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
    if (!(ShapeUtil::Equal(first_reduce_shape, inst_shape) &&
          ShapeUtil::Equal(first_reduce_operand_shape, inst_operand_shape) &&
          ShapeUtil::Equal(first_reduce_init_shape, inst_init_shape) &&
          absl::c_equal(first_reduce.dimensions(), inst.dimensions()))) {
      return false;
    }
  } else {
    if (!(ShapeUtil::CompatibleIgnoringElementType(first_reduce_operand_shape,
                                                   inst_shape) &&
          LayoutUtil::Equal(first_reduce_operand_shape.layout(),
                            inst_shape.layout()))) {
      return false;
    }
  }
  return true;
}

// Given an LMHLO op, returns the operand index of the first output operand.
//
// Note that an operand aliased to an output isn't an output, even though in
// that case WritesMlirBuffer() returns true on that operand.
//
// An operand either is !WritesMlirBuffer() or equals (aliases) a later
// operand. An output is the opposite: it is WritesMlirBuffer() and does not
// equal any later operand.
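//
// For example, for an lmhlo op with operands (%in0, %in1, %out), where only
// %out is written and no operand aliases another, this returns 2.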
int PartitionLmhloOperandsAndOutputs(mlir::Operation* op) {
  CHECK(op->getDialect() == op->getContext()->getLoadedDialect("lmhlo"));

  int i;
  for (i = op->getOperands().size() - 1; i >= 0; i--) {
    const bool aliased =
        std::find(op->getOperands().begin() + i + 1, op->getOperands().end(),
                  op->getOperand(i)) != op->getOperands().end();
    if (!WritesMlirBuffer(op, op->getOperand(i)) || aliased) {
      break;
    }
  }
  return i + 1;
}

std::vector<mlir::Value> GetHloOperands(mlir::Operation* op) {
  if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
    return ToStdVector(fusion.getInputBuffers());
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("lmhlo")) {
    int output_start = PartitionLmhloOperandsAndOutputs(op);
    std::vector<mlir::Value> operands;
    operands.reserve(output_start);
    for (int i = 0; i < output_start; i++) {
      operands.push_back(op->getOperand(i));
    }
    return operands;
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("mhlo")) {
    return std::vector<mlir::Value>(op->getOperands().begin(),
                                    op->getOperands().end());
  }
  LOG(FATAL) << "Unexpected op: " << MlirToString(op);
}

std::vector<mlir::Value> GetHloOutputs(mlir::Operation* op) {
  if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
    return ToStdVector(fusion.getOutputBuffers());
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("lmhlo")) {
    int output_start = PartitionLmhloOperandsAndOutputs(op);
    std::vector<mlir::Value> outputs;
    for (int i = output_start; i < op->getNumOperands(); i++) {
      outputs.push_back(op->getOperand(i));
    }
    return outputs;
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("mhlo")) {
    return std::vector<mlir::Value>(op->getResults().begin(),
                                    op->getResults().end());
  }
  LOG(FATAL) << "Unexpected op: " << MlirToString(op);
}

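// Returns true if `op` has a memory write effect on `operand`.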
bool WritesMlirBuffer(mlir::Operation* op, mlir::Value operand) {
  llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 2> effects;
  mlir::cast<mlir::MemoryEffectOpInterface>(op).getEffectsOnValue(operand,
                                                                  effects);
  return absl::c_any_of(
      effects, [](const mlir::MemoryEffects::EffectInstance& instance) {
        return mlir::isa<mlir::MemoryEffects::Write>(instance.getEffect());
      });
}

static int64_t GetMemRefSizeInBytes(mlir::MemRefType type) {
  // For i1 memrefs, the underlying allocation is 8 bits.
  if (type.getElementType().isInteger(/*width=*/1)) {
    return type.getNumElements();
  } else {
    return type.getSizeInBits() / CHAR_BIT;
  }
}

static int64_t GetAllocationIndex(mlir::BlockArgument func_arg) {
  auto func_op =
      mlir::cast<mlir::FuncOp>(func_arg.getParentRegion()->getParentOp());
  return func_op
      .getArgAttrOfType<mlir::IntegerAttr>(func_arg.getArgNumber(),
                                           "lmhlo.alloc")
      .getValue()
      .getSExtValue();
}

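// Maps an MLIR buffer value to the BufferAllocation::Slice it refers to, using
// the "lmhlo.alloc" attribute of the function argument or global memref that
// the value ultimately derives from.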
StatusOr<BufferAllocation::Slice> GetAllocationSliceForMlir(
    mlir::Value v, absl::Span<const BufferAllocation> allocations) {
  int64 size = GetMemRefSizeInBytes(v.getType().cast<mlir::MemRefType>());

  if (auto arg = v.dyn_cast<mlir::BlockArgument>()) {
    return BufferAllocation::Slice(&allocations[GetAllocationIndex(arg)], 0,
                                   size);
  }

  // We match the following patterns here:
  //  base := ViewOp(arg) | get_global_memref (global_memref)
  //  root := base | MemRefReinterpretCastOp(base)

  if (mlir::Operation* op = v.getDefiningOp()) {
    if (auto cast = mlir::dyn_cast<mlir::MemRefReinterpretCastOp>(op)) {
      mlir::Value source = cast.getViewSource();
      op = source.getDefiningOp();
      if (!op) {
        return Unimplemented("MemRefReinterpretCastOp has to wrap an op");
      }
    }
    if (auto view = mlir::dyn_cast<mlir::ViewOp>(op)) {
      return BufferAllocation::Slice(
          &allocations[GetAllocationIndex(
              view.source().cast<mlir::BlockArgument>())],
          mlir::cast<mlir::ConstantOp>(view.byte_shift().getDefiningOp())
              .value()
              .cast<mlir::IntegerAttr>()
              .getValue()
              .getSExtValue(),
          size);
    } else if (auto get_global = mlir::dyn_cast<mlir::GetGlobalMemrefOp>(op)) {
      auto module = get_global->getParentOfType<mlir::ModuleOp>();
      auto global = mlir::cast<mlir::GlobalMemrefOp>(
          module.lookupSymbol(get_global.name()));
      int64_t index =
          global->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc").getInt();
      return BufferAllocation::Slice(&allocations[index], 0,
                                     allocations[index].size());
    }
    return Unimplemented("MemRefReinterpretCastOp has to wrap a ViewOp");
  }

  return Unimplemented(
      "Operand has to be in the form of ViewOp(arg) or "
      "StaticMemRefCastOp(ViewOp(arg))");
}

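// Returns true if the fusion has a single result rooted at a
// dynamic-update-slice whose updated operand lives in the same allocation
// slice as the fusion output, so the update can be emitted in place.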
bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
    mlir::lmhlo::FusionOp fusion,
    absl::Span<const BufferAllocation> allocations) {
  auto results = fusion.getFusionResults();
  if (results.size() != 1) {
    return false;
  }
  auto dus = mlir::dyn_cast<mlir::mhlo::DynamicUpdateSliceOp>(
      results[0].getDefiningOp());
  if (!dus) {
    return false;
  }

  auto output_buffers = fusion.getOutputBuffers();
  CHECK_EQ(1, output_buffers.size());
  auto parameter =
      mlir::dyn_cast<mlir::TensorLoadOp>(dus.operand().getDefiningOp());

  if (!parameter) {
    return false;
  }

  auto maybe_lhs = GetAllocationSliceForMlir(parameter.memref(), allocations);
  auto maybe_rhs = GetAllocationSliceForMlir(output_buffers[0], allocations);
  return maybe_lhs.ok() && maybe_rhs.ok() && *maybe_lhs == *maybe_rhs;
}

}  // namespace gpu
}  // namespace xla