/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"

#include <algorithm>
#include <array>
#include <vector>

#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Module.h"
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/stream_executor/device_description.h"

namespace xla {
namespace gpu {

namespace {

// Returns whether the given shape is rank 2, excluding the batch dimensions.
bool IsRank2(const Shape& shape, int64_t batch_dimensions_size) {
  return shape.rank() == batch_dimensions_size + 2;
}

// Given a shape and a group of contiguous dimensions in the shape, returns
// a tuple of three values (major, middle, minor), where major is the size of
// the dimensions more major than the given dimensions, minor is the size of
// the dimensions more minor than the given dimensions, and middle is the size
// of the given dimensions.
std::array<int64, 3> PartitionShapeByMiddleDimensions(
    const Shape& shape, absl::Span<const int64> dims_middle) {
  CHECK(LayoutUtil::AreDimensionsConsecutive(shape.layout(), dims_middle));
  std::array<int64, 3> values = {1, 1, 1};
  enum Segment { kMajor = 0, kMiddle = 1, kMinor = 2 };
  Segment cur_segment = kMinor;

  for (int64_t cur_dim : LayoutUtil::MinorToMajor(shape)) {
    if (cur_segment != kMajor) {
      // Handle change of segments.
      bool cur_dim_in_middle = absl::c_linear_search(dims_middle, cur_dim);
      if (cur_segment == kMinor) {
        if (cur_dim_in_middle) {
          cur_segment = kMiddle;
        }
      } else if (cur_segment == kMiddle) {
        if (!cur_dim_in_middle) {
          cur_segment = kMajor;
        }
      }
    }
    values[cur_segment] *= shape.dimensions(cur_dim);
  }
  return values;
}

Shape GetShapeFromTensorType(mlir::Value value) {
  constexpr char kDefaultLayoutAttrName[] = "minor_to_major";

  mlir::Operation* op = value.getDefiningOp();
  CHECK(op);
  CHECK(value.getType().isa<mlir::TensorType>());
  Shape shape = TypeToShape(value.getType());
  if (auto attr = op->getAttrOfType<mlir::DenseIntElementsAttr>(
          kDefaultLayoutAttrName)) {
    std::vector<int64> minor_to_major;
    absl::c_transform(
        attr, std::back_inserter(minor_to_major),
        std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
    *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
  } else {
    *shape.mutable_layout() = LayoutUtil::MakeDescendingLayout(
        value.getType().cast<mlir::ShapedType>().getShape().size());
  }
  return shape;
}

}  // namespace

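// Returns true if `dot` is a kDot with supported element types whose operands
// and result are rank 2 modulo batch dimensions and whose operands are
// non-empty.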
bool IsMatrixMultiplication(const HloInstruction& dot) {
  if (dot.opcode() != HloOpcode::kDot) {
    return false;
  }
  const Shape& lhs_shape = dot.operand(0)->shape();
  const Shape& rhs_shape = dot.operand(1)->shape();
  const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();

  PrimitiveType output_primitive_type = dot.shape().element_type();
  bool type_is_allowed =
      (output_primitive_type == F16 || output_primitive_type == BF16 ||
       output_primitive_type == F32 || output_primitive_type == F64 ||
       output_primitive_type == C64 || output_primitive_type == C128) ||
      (output_primitive_type == S32 && lhs_shape.element_type() == S8 &&
       rhs_shape.element_type() == S8);
  bool shapes_are_valid =
      type_is_allowed &&
      IsRank2(lhs_shape, dim_numbers.lhs_batch_dimensions_size()) &&
      IsRank2(rhs_shape, dim_numbers.lhs_batch_dimensions_size()) &&
      IsRank2(dot.shape(), dim_numbers.lhs_batch_dimensions_size()) &&
      !ShapeUtil::IsZeroElementArray(lhs_shape) &&
      !ShapeUtil::IsZeroElementArray(rhs_shape);

  if (!shapes_are_valid) {
    return false;
  }

  // The sizes of the contracting dimensions should match. Shape inference
  // guarantees this invariant, so the check here is to catch programming
  // errors.
  CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
           rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));

  return true;
}

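// Returns true if `hlo` is a custom call targeting the cuBLAS GEMM call
// target (kGemmCallTarget).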
bool IsCublasGemm(const HloInstruction& hlo) {
  return hlo.opcode() == HloOpcode::kCustomCall &&
         hlo.custom_call_target() == kGemmCallTarget;
}

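// Returns the {tile_z, tile_y, tile_x} tiling to use for the given reduction
// dimensions. The dtype and compute-capability parameters are currently
// unused.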
std::array<int64, 3> GetReductionTiling(
    const ReductionDimensions& reduction_dimensions,
    int smallest_input_dtype_bits,
    se::CudaComputeCapability cuda_compute_capability) {
  if (reduction_dimensions.is_row_reduction) {
    int64_t tile_z = std::min(reduction_dimensions.dimensions[0],
                              kBatchedReductionRaceFreeBound);
    return {tile_z, 1, 64};
  }

  // Column reduction.
  return {1, 128, 1};
}

const char* const kCudnnBatchNormForwardInferenceCallTarget =
    "__cudnn$batchNormalizationForwardInference";
const char* const kCudnnBatchNormForwardTrainingCallTarget =
    "__cudnn$batchNormalizationForwardTraining";
const char* const kCudnnBatchNormBackwardCallTarget =
    "__cudnn$batchNormalizationBackward";

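// Returns true if `hlo` is a custom call to one of the cuDNN batch
// normalization targets declared above.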
bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCudnnBatchNormForwardInferenceCallTarget ||
         target == kCudnnBatchNormForwardTrainingCallTarget ||
         target == kCudnnBatchNormBackwardCallTarget;
}

const char* const kGemmCallTarget = "__cublas$gemm";
const char* const kCudnnConvForwardCallTarget = "__cudnn$convForward";
const char* const kCudnnConvBackwardInputCallTarget =
    "__cudnn$convBackwardInput";
const char* const kCudnnConvBackwardFilterCallTarget =
    "__cudnn$convBackwardFilter";
const char* const kCudnnConvBiasActivationForwardCallTarget =
    "__cudnn$convBiasActivationForward";

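// Returns true if `hlo` is a custom call to one of the cuDNN convolution
// targets declared above.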
bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCudnnConvForwardCallTarget ||
         target == kCudnnConvBackwardInputCallTarget ||
         target == kCudnnConvBackwardFilterCallTarget ||
         target == kCudnnConvBiasActivationForwardCallTarget;
}

const char* const kCusolverCholeskyCallTarget = "__cusolver$cholesky";

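// Returns true if `hlo` is a custom call to the cuSOLVER Cholesky target.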
bool IsCustomCallToCusolver(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCusolverCholeskyCallTarget;
}

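// Returns true if `hlo` is implemented as a cuBLAS or cuDNN library call.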
bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
  return IsCublasGemm(hlo) || IsCustomCallToDnnBatchNorm(hlo) ||
         IsCustomCallToDnnConvolution(hlo);
}

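// Classifies a reduction of `dims_to_reduce` over `input_shape` as a row or
// column reduction and returns the sizes of its three contiguous components
// as a ReductionDimensions.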
static ReductionDimensions GetReductionKindAndContiguousComponentsImpl(
    const Shape& input_shape, absl::Span<const int64> dims_to_reduce) {
  DimensionVector dims_to_keep;
  for (int64_t dim = 0; dim < input_shape.rank(); ++dim) {
    if (!absl::c_linear_search(dims_to_reduce, dim)) {
      dims_to_keep.push_back(dim);
    }
  }

  if (dims_to_keep.empty()) {
    return {/*is_row_reduction=*/true,
            {1, 1, ShapeUtil::ElementsIn(input_shape)}};
  }

  if (LayoutUtil::AreDimensionsConsecutive(input_shape.layout(),
                                           dims_to_keep)) {
    std::array<int64, 3> shape_partition =
        PartitionShapeByMiddleDimensions(input_shape, dims_to_keep);
    if (shape_partition[1] == 1) {
      return {/*is_row_reduction=*/true,
              {1, 1, shape_partition[0] * shape_partition[2]}};
    }
    if (shape_partition[2] == 1) {
      return {/*is_row_reduction=*/false,
              {1, shape_partition[0], shape_partition[1]}};
    }
    return {/*is_row_reduction=*/true, shape_partition};
  }

  std::array<int64, 3> shape_partition =
      PartitionShapeByMiddleDimensions(input_shape, dims_to_reduce);

  if (shape_partition[2] == 1) {
    return {/*is_row_reduction=*/true,
            {1, shape_partition[0], shape_partition[1]}};
  }
  return {/*is_row_reduction=*/false, shape_partition};
}

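// Returns true if the reduced dimension is at least a warp wide, i.e. large
// enough for the tiled reduction emitter to be profitable.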
static bool IsUnnestedReductionFasterThanElemental(
    const ReductionDimensions& reduction_dimensions) {
  if (reduction_dimensions.is_row_reduction) {
    // For row reduction, the tile block is 1 x tile_size_x, and we are reducing
    // along tile_size_x which needs to be large enough to make the tiling
    // implementation efficient.
    return reduction_dimensions.dimensions[2] >= kWarpSize;
  }

  // For column reduction, the tile block is tile_size_y x tile_size_x, and we
  // are reducing along tile_size_y. Only tile_size_y needs to be
  // large enough to make the tiling implementation efficient.
  return reduction_dimensions.dimensions[1] >= kWarpSize;
}

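// Returns true if `reduce` is a kReduce whose kept or reduced dimensions are
// contiguous in the input layout and whose reduced extent is large enough for
// the tiled reduction emitter.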
bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) {
  if (HloOpcode::kReduce != reduce.opcode()) {
    return false;
  }

  // TODO(b/129698548): Remove this check after fixing the bug.
  if (reduce.shape().element_type() == C128) {
    return false;
  }

  const HloInstruction* input = reduce.operand(0);
  std::vector<int64> dims_to_keep;
  for (int64_t dim = 0; dim < input->shape().dimensions().size(); ++dim) {
    if (!absl::c_linear_search(reduce.dimensions(), dim)) {
      dims_to_keep.push_back(dim);
    }
  }

  // We support fast codegen for three cases:
  // 1) Row reduction: (K, R)
  // 2) Column reduction: (K, R, K)
  // 3) "Batched" row reduction: (R, K, R)
  if (!LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
                                            dims_to_keep) &&
      !LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
                                            reduce.dimensions())) {
    return false;
  }

  return IsUnnestedReductionFasterThanElemental(
      GetReductionKindAndContiguousComponents(reduce));
}

// Constructs the fusion layout analysis object by using a heuristic to infer
// the layout of a fusion-internal value. In general, if the value is derived
// from a fusion parameter (which by definition has a layout) using elementwise
// operations, it will inherit the layout of that parameter. OTOH, if the value
// is written to a fusion output, it will inherit the layout of that output.
// If the heuristic fails, the default layout is inferred.
FusionLayoutAnalysis::FusionLayoutAnalysis(mlir::lmhlo::FusionOp fusion_op) {
  VLOG(3) << "Analyzing \n" << MlirToString(fusion_op);
  auto add_layout = [this](mlir::Value v, const Layout& layout) {
    layouts_[v] = layout;
    VLOG(3) << "===============\n";
    VLOG(3) << "For value \n" << MlirToString(v.getDefiningOp());
    VLOG(3) << "Layout = " << layout.ToString() << "\n";
    VLOG(3) << "===============\n";
  };

  // Propagate layouts inside the fusion region.
  for (mlir::Operation& op : fusion_op.region().front().without_terminator()) {
    if (auto load = mlir::dyn_cast<mlir::memref::TensorLoadOp>(op)) {
      add_layout(load, GetShape(load.memref()).layout());
    } else if (auto store = mlir::dyn_cast<mlir::memref::TensorStoreOp>(op)) {
      // Propagate the stored memref layout to the value if it does not have an
      // inferred layout already. This prefers load coalescing over stores.
      if (layouts_.count(store.tensor()) == 0) {
        add_layout(store.tensor(), GetShape(store.memref()).layout());
      }
    } else if (auto bitcast = mlir::dyn_cast<mlir::mhlo::BitcastOp>(op)) {
      auto attr =
          bitcast->getAttrOfType<mlir::DenseIntElementsAttr>("result_layout");
      std::vector<int64> minor_to_major;
      absl::c_transform(
          attr, std::back_inserter(minor_to_major),
          std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
      add_layout(bitcast, LayoutUtil::MakeLayout(minor_to_major));

      attr =
          bitcast->getAttrOfType<mlir::DenseIntElementsAttr>("source_layout");
      minor_to_major.clear();
      absl::c_transform(
          attr, std::back_inserter(minor_to_major),
          std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
      add_layout(bitcast.operand(), LayoutUtil::MakeLayout(minor_to_major));
    } else {
      HloOpcode opcode = *xla::MhloToHloOpcode(&op);
      if (!HloInstruction::IsOpElementwise(opcode)) {
        continue;
      }
      // If any operand has a layout, infer that layout for the result of the
      // operation. If two operands have conflicting layouts, we still need to
      // choose one of them, so we arbitrarily choose the first one.
      for (mlir::Value operand : op.getOperands()) {
        auto it = layouts_.find(operand);
        if (it != layouts_.end()) {
          // Do not pass in a reference to an entry in the map when adding a new
          // entry. The map may expand when adding, and the reference may become
          // invalid. To avoid this, create a local copy of the layout.
          const Layout operand_layout = it->second;
          add_layout(op.getResult(0), operand_layout);
          break;
        }
      }
    }
  }
}

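// Returns the shape of `value`. For non-memref values, the layout inferred by
// this analysis (if any) is attached to the shape.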
Shape FusionLayoutAnalysis::GetShape(mlir::Value value) const {
  Shape shape = TypeToShape(value.getType());
  if (!value.getType().isa<mlir::MemRefType>()) {
    auto it = layouts_.find(value);
    if (it != layouts_.end()) {
      *shape.mutable_layout() = it->second;
    }
  }
  return shape;
}

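// MLIR counterpart of IsReductionFromOrToContiguousDimensions above; uses
// `layout_analysis` to recover the operand layout.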
bool IsReductionFromOrToContiguousDimensions(
    mlir::Operation* reduce, const FusionLayoutAnalysis& layout_analysis) {
  if (!mlir::isa<mlir::mhlo::ReduceOp>(reduce)) {
    return false;
  }
  std::vector<mlir::Value> results = GetHloOutputs(reduce);
  CHECK_EQ(1, results.size());

  auto c128_type =
      mlir::ComplexType::get(mlir::FloatType::getF64(reduce->getContext()));

  // TODO(b/129698548): Remove this check after fixing the bug.
  if (results[0].getType().cast<mlir::ShapedType>().getElementType() ==
      c128_type) {
    return false;
  }

  mlir::Value input = reduce->getOperand(0);
  const Shape operand_shape = layout_analysis.GetShape(input);

  // Enable this code to check for mismatches between the inferred layout and
  // the layout that was present in the IR. Based on actual runs, some
  // mismatches are expected.
#if 0
  Shape operand_shape_ir = GetShape(input);
  if (auto tensor_type = input.getType().dyn_cast<mlir::TensorType>()) {
    if (auto attr = mlir::GetLayoutFromMlirHlo(input.getDefiningOp())) {
      std::vector<int64> minor_to_major;
      absl::c_transform(
          attr, std::back_inserter(minor_to_major),
          std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
      *operand_shape_ir.mutable_layout() =
          LayoutUtil::MakeLayout(minor_to_major);
    }
  }
  bool match = ShapeUtil::Equal(operand_shape, operand_shape_ir);
  llvm::errs() << "inferred shape = " << operand_shape.ToString(true) << "\n";
  llvm::errs() << "Actual shape in IR = " << operand_shape_ir.ToString(true)
               << "\n";
  if (!match) {
    llvm::errs() << "Unable to infer layout for reduce op operand(0)\n";
    llvm::errs() << "\nreduce = \n";
    reduce->dump();
    llvm::errs() << "\nparent = \n";
    reduce->getParentOp()->dump();
    CHECK(0);
  }
#endif

  std::vector<int64> dimensions;
  {
    auto attr = reduce->getAttrOfType<mlir::DenseIntElementsAttr>("dimensions");
    CHECK(attr);
    absl::c_transform(
        attr, std::back_inserter(dimensions),
        std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
  }

  std::vector<int64> dims_to_keep;
  for (int64_t dim = 0; dim < operand_shape.dimensions().size(); ++dim) {
    if (!absl::c_linear_search(dimensions, dim)) {
      dims_to_keep.push_back(dim);
    }
  }

  // We support fast codegen for three cases:
  // 1) Row reduction: (K, R)
  // 2) Column reduction: (K, R, K)
  // 3) "Batched" row reduction: (R, K, R)
  if (!LayoutUtil::AreDimensionsConsecutive(operand_shape.layout(),
                                            dims_to_keep) &&
      !LayoutUtil::AreDimensionsConsecutive(operand_shape.layout(),
                                            dimensions)) {
    return false;
  }

  return IsUnnestedReductionFasterThanElemental(
      GetReductionKindAndContiguousComponentsImpl(operand_shape, dimensions));
}

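// Returns true if `unnested_hlo` is a fusion whose results are all produced by
// mhlo::SliceOps, optionally requiring all slices to have unit strides.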
bool IsInputFusibleSlices(mlir::Operation* unnested_hlo,
                          bool verify_no_strides) {
  auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo);
  if (!fusion) {
    return false;
  }

  auto is_non_strided = [](mlir::DenseIntElementsAttr strides) -> bool {
    return absl::c_all_of(
        strides, [](const llvm::APInt& stride) { return stride == 1; });
  };

  for (mlir::Value value : fusion.getFusionResults()) {
    auto slice =
        mlir::dyn_cast_or_null<mlir::mhlo::SliceOp>(value.getDefiningOp());
    if (!slice) {
      return false;
    }
    if (verify_no_strides && !is_non_strided(slice.strides())) {
      return false;
    }
  }
  return true;
}

ReductionDimensions GetReductionKindAndContiguousComponents(
    const HloInstruction& reduce) {
  return GetReductionKindAndContiguousComponentsImpl(reduce.operand(0)->shape(),
                                                     reduce.dimensions());
}

ReductionDimensions GetReductionKindAndContiguousComponents(
    mlir::Operation* reduce) {
  mlir::Value input = reduce->getOperand(0);
  Shape operand_shape = GetShape(input);
  std::vector<int64> dimensions;
  {
    auto attr = reduce->getAttrOfType<mlir::DenseIntElementsAttr>("dimensions");
    CHECK(attr);
    absl::c_transform(
        attr, std::back_inserter(dimensions),
        std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
  }
  return GetReductionKindAndContiguousComponentsImpl(operand_shape, dimensions);
}

// This emits a device-side call to
// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see
// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
llvm::Value* EmitPrintf(absl::string_view fmt,
                        absl::Span<llvm::Value* const> arguments,
                        llvm::IRBuilder<>* builder) {
  std::vector<llvm::Type*> argument_types;

  // Implicit promotion of variadic arguments [1] converts float to double,
  // and bool/char/short to int.
  // [1] https://en.cppreference.com/w/cpp/language/variadic_arguments
  auto requires_int32_promotion = [](llvm::Type* type) {
    return type->isIntegerTy(/*BitWidth=*/1) ||
           type->isIntegerTy(/*BitWidth=*/8) ||
           type->isIntegerTy(/*BitWidth=*/16);
  };
  auto requires_double_promotion = [](llvm::Type* type) {
    return type->isFloatingPointTy();
  };

  for (auto argument : arguments) {
    llvm::Type* type = argument->getType();
    if (requires_double_promotion(type)) {
      argument_types.push_back(builder->getDoubleTy());
    } else if (requires_int32_promotion(type)) {
      argument_types.push_back(builder->getInt32Ty());
    } else {
      argument_types.push_back(type);
    }
  }
  auto* arguments_type = llvm::StructType::create(argument_types);
  llvm::Value* arguments_ptr = builder->CreateAlloca(arguments_type);
  for (size_t i = 0; i < arguments.size(); ++i) {
    llvm::Value* value = arguments[i];
    llvm::Type* type = value->getType();
    if (requires_double_promotion(type)) {
      value = builder->CreateFPCast(value, builder->getDoubleTy());
    } else if (requires_int32_promotion(type)) {
      value = builder->CreateIntCast(value, builder->getInt32Ty(),
                                     /*isSigned=*/true);
    }
    builder->CreateStore(
        value, builder->CreateGEP(arguments_ptr, {builder->getInt64(0),
                                                  builder->getInt32(i)}));
  }
  llvm::Type* ptr_ty = builder->getInt8Ty()->getPointerTo();
  return builder->CreateCall(
      builder->GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
          "vprintf",
          llvm::FunctionType::get(builder->getInt32Ty(), {ptr_ty, ptr_ty},
                                  /*isVarArg=*/false)),
      {builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)),
       builder->CreatePointerCast(arguments_ptr, ptr_ty)});
}

// Helper function to emit a call to the AMDGPU shfl_down function.
llvm::Value* EmitAMDGPUShflDown(llvm::Value* value, llvm::Value* offset,
                                llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32);
  auto* i32_ty = b->getInt32Ty();
  llvm::FunctionCallee shfl_fn = module->getOrInsertFunction(
      llvm_ir::AsStringRef("__ockl_readuplane_i32"),
      llvm::FunctionType::get(/*Result=*/i32_ty, {i32_ty, i32_ty},
                              /*isVarArg=*/false));
  // The AMDGPU device function requires the first argument to be i32.
  llvm::Value* result =
      b->CreateCall(shfl_fn, {b->CreateBitCast(value, i32_ty), offset});
  // The AMDGPU device function always returns an i32 type.
  return b->CreateBitCast(result, value->getType());
}

// Helper function to emit a call to the NVPTX shfl_down intrinsic.
llvm::Value* EmitNVPTXShflDown(llvm::Value* value, llvm::Value* offset,
                               llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  llvm::Intrinsic::ID llvm_intrinsic_id;
  CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32);
  if (value->getType()->isFloatTy()) {
    llvm_intrinsic_id = llvm::Intrinsic::nvvm_shfl_sync_down_f32;
  } else {
    llvm_intrinsic_id = llvm::Intrinsic::nvvm_shfl_sync_down_i32;
  }
  llvm::Function* intrinsic =
      llvm::Intrinsic::getDeclaration(module, llvm_intrinsic_id, {});
  return b->CreateCall(
      intrinsic, {b->getInt32(-1), value, offset, b->getInt32(kWarpSize - 1)});
}

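// Emits a full-warp shuffle-down of `value` by `offset`, splitting values
// wider than 32 bits into 32-bit segments as required by the shfl
// instructions.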
llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
                                     llvm::IRBuilder<>* builder) {
  int bit_width = value->getType()->getPrimitiveSizeInBits();
  llvm::Module* module = builder->GetInsertBlock()->getModule();
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());

  // Special case for efficiency.
  if (value->getType()->isFloatTy() && bit_width == 32) {
    if (target_triple.isNVPTX()) {
      return EmitNVPTXShflDown(value, offset, builder);
    } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
      return EmitAMDGPUShflDown(value, offset, builder);
    } else {
      LOG(FATAL) << "Invalid triple " << target_triple.str();
    }
  }

  // We must split values wider than 32 bits as the "shfl" instruction operates
  // on 32-bit values.
  int num_segments = CeilOfRatio(bit_width, 32);
  llvm::Value* x = builder->CreateBitCast(
      builder->CreateZExt(
          builder->CreateBitCast(value, builder->getIntNTy(bit_width)),
          builder->getIntNTy(32 * num_segments)),
      llvm::VectorType::get(builder->getInt32Ty(), num_segments, false));
  for (int i = 0; i < num_segments; ++i) {
    llvm::Value* insert_val;
    if (target_triple.isNVPTX()) {
      insert_val = EmitNVPTXShflDown(builder->CreateExtractElement(x, i),
                                     offset, builder);
    } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
      insert_val = EmitAMDGPUShflDown(builder->CreateExtractElement(x, i),
                                      offset, builder);
    } else {
      LOG(FATAL) << "Invalid triple " << target_triple.str();
    }
    x = builder->CreateInsertElement(x, insert_val, i);
  }
  return builder->CreateBitCast(
      builder->CreateTrunc(
          builder->CreateBitCast(x, builder->getIntNTy(32 * num_segments)),
          builder->getIntNTy(bit_width)),
      value->getType());
}

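// Maps the custom-call target of a cuDNN convolution to its CudnnConvKind, or
// returns an error for an unexpected target.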
StatusOr<CudnnConvKind> GetCudnnConvKind(
    const HloCustomCallInstruction* instr) {
  absl::string_view target = instr->custom_call_target();
  if (target == kCudnnConvForwardCallTarget) {
    return CudnnConvKind::kForward;
  }
  if (target == kCudnnConvBackwardInputCallTarget) {
    return CudnnConvKind::kBackwardInput;
  }
  if (target == kCudnnConvBackwardFilterCallTarget) {
    return CudnnConvKind::kBackwardFilter;
  }
  if (target == kCudnnConvBiasActivationForwardCallTarget) {
    return CudnnConvKind::kForwardActivation;
  }
  return InternalError("Unexpected call target: %s", target);
}

string CudnnConvKindToString(CudnnConvKind kind) {
  switch (kind) {
    case CudnnConvKind::kForward:
      return "forward";
    case CudnnConvKind::kBackwardFilter:
      return "backward_filter";
    case CudnnConvKind::kBackwardInput:
      return "backward_input";
    case CudnnConvKind::kForwardActivation:
      return "forward with activation";
  }
}

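// Emits a predicate that is true only for the first thread of the first
// block.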
llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) {
  llvm::Value* is_thread0 = b->CreateICmpEQ(
      b->getInt32(0),
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b));

  llvm::Value* is_block0 = b->CreateICmpEQ(
      b->getInt32(0),
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b));
  return b->CreateAnd(is_thread0, is_block0);
}

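// Returns true if `inst` can be emitted in the same reduction fusion as
// `first_reduce`: either it is a matching contiguous reduction, or its shape
// and layout are compatible with the reduction input.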
bool IsFusedReductionOutputConsistent(const HloInstruction* inst,
                                      const HloInstruction* first_reduce) {
  if (IsReductionFromOrToContiguousDimensions(*inst)) {
    // Shapes, layouts and dimensions must be the same for all reduces
    // inside of this fusion.
    // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
    return ShapeUtil::Equal(first_reduce->shape(), inst->shape()) &&
           ShapeUtil::Equal(first_reduce->operand(0)->shape(),
                            inst->operand(0)->shape()) &&
           ShapeUtil::Equal(first_reduce->operand(1)->shape(),
                            inst->operand(1)->shape()) &&
           first_reduce->dimensions() == inst->dimensions();
  }
  return ShapeUtil::CompatibleIgnoringElementType(
             first_reduce->operand(0)->shape(), inst->shape()) &&
         LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
                           inst->shape().layout());
}

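// MLIR counterpart of IsFusedReductionOutputConsistent above, using layouts
// inferred by `layout_analysis`.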
bool IsFusedReductionOutputConsistent(
    mlir::mhlo::ReduceOp inst, mlir::mhlo::ReduceOp first_reduce,
    const FusionLayoutAnalysis& layout_analysis) {
  CHECK_EQ(1, first_reduce.getNumResults());
  Shape first_reduce_operand_shape =
      layout_analysis.GetShape(first_reduce.inputs()[0]);
  CHECK_EQ(1, inst.getNumResults());
  Shape inst_shape = layout_analysis.GetShape(inst.getResult(0));

  if (IsReductionFromOrToContiguousDimensions(inst, layout_analysis)) {
    Shape first_reduce_shape =
        layout_analysis.GetShape(first_reduce.getResult(0));
    Shape first_reduce_init_shape =
        layout_analysis.GetShape(first_reduce.init_values()[0]);

    Shape inst_operand_shape = layout_analysis.GetShape(inst.inputs()[0]);
    Shape inst_init_shape = layout_analysis.GetShape(inst.init_values()[0]);

    // Shapes, layouts and dimensions must be the same for all reduces
    // inside of this fusion.
    // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
    if (!(ShapeUtil::Equal(first_reduce_shape, inst_shape) &&
          ShapeUtil::Equal(first_reduce_operand_shape, inst_operand_shape) &&
          ShapeUtil::Equal(first_reduce_init_shape, inst_init_shape) &&
          absl::c_equal(first_reduce.dimensions(), inst.dimensions()))) {
      return false;
    }
  } else {
    if (!(ShapeUtil::CompatibleIgnoringElementType(first_reduce_operand_shape,
                                                   inst_shape) &&
          LayoutUtil::Equal(first_reduce_operand_shape.layout(),
                            inst_shape.layout()))) {
      return false;
    }
  }
  return true;
}

// Given an LMHLO op, returns the operand index of the first output operand.
//
// Note that an operand aliased to an output isn't itself an output, even
// though WritesMlirBuffer() returns true on that operand.
//
// An operand either is !WritesMlirBuffer() or aliases (equals) a later
// operand. An output is the opposite: it WritesMlirBuffer() and does not equal
// any later operand.
int PartitionLmhloOperandsAndOutputs(mlir::Operation* op) {
  CHECK(op->getDialect() == op->getContext()->getLoadedDialect("lmhlo"));

  int i;
  for (i = op->getOperands().size() - 1; i >= 0; i--) {
    const bool aliased =
        std::find(op->getOperands().begin() + i + 1, op->getOperands().end(),
                  op->getOperand(i)) != op->getOperands().end();
    if (!WritesMlirBuffer(op, op->getOperand(i)) || aliased) {
      break;
    }
  }
  return i + 1;
}

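// Returns the "HLO-level" operands of `op`: the input buffers for LMHLO
// fusions and ops, or the SSA operands for MHLO ops.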
std::vector<mlir::Value> GetHloOperands(mlir::Operation* op) {
  if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
    return ToStdVector(fusion.getInputBuffers());
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("lmhlo")) {
    int output_start = PartitionLmhloOperandsAndOutputs(op);
    std::vector<mlir::Value> operands;
    operands.reserve(output_start);
    for (int i = 0; i < output_start; i++) {
      operands.push_back(op->getOperand(i));
    }
    return operands;
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("mhlo")) {
    return std::vector<mlir::Value>(op->getOperands().begin(),
                                    op->getOperands().end());
  }
  LOG(FATAL) << "Unexpected op: " << MlirToString(op);
}

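// Returns the "HLO-level" outputs of `op`: the output buffers for LMHLO
// fusions and ops, or the SSA results for MHLO ops.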
std::vector<mlir::Value> GetHloOutputs(mlir::Operation* op) {
  if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
    return ToStdVector(fusion.getOutputBuffers());
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("lmhlo")) {
    int output_start = PartitionLmhloOperandsAndOutputs(op);
    std::vector<mlir::Value> outputs;
    for (int i = output_start; i < op->getNumOperands(); i++) {
      outputs.push_back(op->getOperand(i));
    }
    return outputs;
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("mhlo")) {
    return std::vector<mlir::Value>(op->getResults().begin(),
                                    op->getResults().end());
  }
  LOG(FATAL) << "Unexpected op: " << MlirToString(op);
}

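// Returns true if `op` declares a memory write effect on `operand`.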
bool WritesMlirBuffer(mlir::Operation* op, mlir::Value operand) {
  llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 2> effects;
  mlir::cast<mlir::MemoryEffectOpInterface>(op).getEffectsOnValue(operand,
                                                                  effects);
  return absl::c_any_of(
      effects, [](const mlir::MemoryEffects::EffectInstance& instance) {
        return mlir::isa<mlir::MemoryEffects::Write>(instance.getEffect());
      });
}

static int64_t GetMemRefSizeInBytes(mlir::MemRefType type) {
  // For i1 memrefs, the underlying allocation is 8 bits.
  if (type.getElementType().isInteger(/*width=*/1)) {
    return type.getNumElements();
  } else {
    return type.getSizeInBits() / CHAR_BIT;
  }
}

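// Returns the index of the buffer allocation backing the given kernel
// argument and, if requested, its "lmhlo.constant_name" attribute.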
static int64_t GetAllocationIndex(mlir::BlockArgument func_arg,
                                  std::string* constant_name) {
  auto func_op =
      mlir::cast<mlir::FuncOp>(func_arg.getParentRegion()->getParentOp());
  if (constant_name) {
    if (auto constant_name_attr = func_op.getArgAttrOfType<mlir::StringAttr>(
            func_arg.getArgNumber(), "lmhlo.constant_name")) {
      *constant_name = constant_name_attr.getValue().str();
    }
  }
  return func_arg.getArgNumber();
}

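// Resolves `v` to a slice of one of `allocations` by looking through
// memref.reinterpret_cast, memref.view, and memref.get_global ops (see the
// patterns matched below).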
StatusOr<BufferAllocation::Slice> GetAllocationSlice(
    mlir::Value v, absl::Span<const BufferAllocation> allocations,
    std::string* constant_name) {
  if (constant_name) {
    constant_name->clear();
  }

  int64_t size = GetMemRefSizeInBytes(v.getType().cast<mlir::MemRefType>());

  // We match the following patterns here:
  //  base := ViewOp(arg) | get_global_memref (global_memref) | arg
  //  root := base | MemRefReinterpretCastOp(base)

  if (auto cast = mlir::dyn_cast_or_null<mlir::memref::ReinterpretCastOp>(
          v.getDefiningOp())) {
    v = cast.getViewSource();
  }
  if (auto view =
          mlir::dyn_cast_or_null<mlir::memref::ViewOp>(v.getDefiningOp())) {
    TF_RET_CHECK(view.source().isa<mlir::BlockArgument>());

    return BufferAllocation::Slice(
        &allocations[GetAllocationIndex(
            view.source().cast<mlir::BlockArgument>(), constant_name)],
        mlir::cast<mlir::ConstantOp>(view.byte_shift().getDefiningOp())
            .value()
            .cast<mlir::IntegerAttr>()
            .getValue()
            .getSExtValue(),
        size);
  }
  if (auto get_global = mlir::dyn_cast_or_null<mlir::memref::GetGlobalOp>(
          v.getDefiningOp())) {
    auto module = get_global->getParentOfType<mlir::ModuleOp>();
    if (constant_name) {
      *constant_name = get_global.name().str();
    }
    auto global = mlir::cast<mlir::memref::GlobalOp>(
        module.lookupSymbol(get_global.name()));
    int64_t index =
        global->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc").getInt();
    return BufferAllocation::Slice(&allocations[index], 0,
                                   allocations[index].size());
  }
  if (auto arg = v.dyn_cast<mlir::BlockArgument>()) {
    return BufferAllocation::Slice(
        &allocations[GetAllocationIndex(arg, constant_name)], 0, size);
  }

  return Unimplemented(
      "Operand has to be in the form of ViewOp(arg) or "
      "StaticMemRefCastOp(ViewOp(arg)) or arg");
}

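// Returns true if `fusion` consists of a single dynamic-update-slice whose
// updated operand and the fusion output map to the same buffer slice, so the
// update can be emitted in place.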
bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
    mlir::lmhlo::FusionOp fusion,
    absl::Span<const BufferAllocation> allocations) {
  auto results = fusion.getFusionResults();
  if (results.size() != 1) {
    return false;
  }
  auto dus = mlir::dyn_cast<mlir::mhlo::DynamicUpdateSliceOp>(
      results[0].getDefiningOp());
  if (!dus) {
    return false;
  }

  auto output_buffers = fusion.getOutputBuffers();
  CHECK_EQ(1, output_buffers.size());
  auto parameter =
      mlir::dyn_cast<mlir::memref::TensorLoadOp>(dus.operand().getDefiningOp());

  if (!parameter) {
    return false;
  }

  auto maybe_lhs = GetAllocationSlice(parameter.memref(), allocations);
  auto maybe_rhs = GetAllocationSlice(output_buffers[0], allocations);
  return maybe_lhs.ok() && maybe_rhs.ok() && *maybe_lhs == *maybe_rhs;
}

Shape GetShape(mlir::Value value) {
  if (value.getType().isa<mlir::MemRefType>()) {
    return TypeToShape(value.getType());
  } else if (value.getType().isa<mlir::TensorType>()) {
    return GetShapeFromTensorType(value);
  } else if (value.getType().isa<mlir::TupleType>()) {
    return TypeToShape(value.getType());
  }
  LOG(FATAL) << "Unexpected value type to get shape for";
  return {};
}

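// Returns true if a reduction with the given dimensions stays within the
// race-free bounds for the given per-dimension tiling.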
bool ReductionIsRaceFree(const ReductionDimensions& reduction_dimensions,
                         const std::array<int64_t, 3>& reduction_tiling) {
  return (reduction_dimensions.is_row_reduction &&
          reduction_dimensions.dimensions[2] <=
              kMinThreadsXRowReduction * reduction_tiling[2] &&
          reduction_dimensions.dimensions[0] <=
              kBatchedReductionRaceFreeBound) ||
         (!reduction_dimensions.is_row_reduction &&
          reduction_dimensions.dimensions[1] <=
              kWarpSize * reduction_tiling[1]);
}

}  // namespace gpu
}  // namespace xla