/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"

#include <algorithm>
#include <array>
#include <vector>

#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

namespace xla {
namespace gpu {

namespace {

// Return whether the given shape is rank 2 excluding the batch dimensions.
bool IsRank2(const Shape& shape, int64 batch_dimensions_size) {
  return shape.rank() == batch_dimensions_size + 2;
}

// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
                        const Shape& output_shape,
                        int64 batch_dimensions_size) {
  // The inputs and the output must
  // 1) be matrices with no padding and a non-zero number of elements,
  // 2) have an allowed element type.
  PrimitiveType output_primitive_type = output_shape.element_type();
  bool type_is_allowed =
      (output_primitive_type == F16 || output_primitive_type == F32 ||
       output_primitive_type == F64 || output_primitive_type == C64 ||
       output_primitive_type == C128);
  return type_is_allowed && IsRank2(lhs_shape, batch_dimensions_size) &&
         IsRank2(rhs_shape, batch_dimensions_size) &&
         IsRank2(output_shape, batch_dimensions_size) &&
         !ShapeUtil::IsZeroElementArray(lhs_shape) &&
         !ShapeUtil::IsZeroElementArray(rhs_shape);
}

// Given a shape and a group of contiguous dimensions in the shape, returns
// a tuple of three values (major, middle, minor), where major is the size of
// the dimensions more major than the given dimensions, minor is the size of
// dimensions more minor than the given dimensions, and middle is the size of
// the given dimensions.
std::array<int64, 3> PartitionShapeByMiddleDimensions(
    const Shape& shape, absl::Span<const int64> dims_middle) {
  CHECK(LayoutUtil::AreDimensionsConsecutive(shape.layout(), dims_middle));
  std::array<int64, 3> values = {1, 1, 1};
  enum Segment { kMajor = 0, kMiddle = 1, kMinor = 2 };
  Segment cur_segment = kMinor;

  for (int64 cur_dim : LayoutUtil::MinorToMajor(shape)) {
    if (cur_segment != kMajor) {
      // Handle change of segments.
      bool cur_dim_in_middle = absl::c_linear_search(dims_middle, cur_dim);
      if (cur_segment == kMinor) {
        if (cur_dim_in_middle) {
          cur_segment = kMiddle;
        }
      } else if (cur_segment == kMiddle) {
        if (!cur_dim_in_middle) {
          cur_segment = kMajor;
        }
      }
    }
    values[cur_segment] *= shape.dimensions(cur_dim);
  }
  return values;
}

}  // namespace

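// Returns true if `dot` is a kDot instruction whose operand and output
// shapes can be handled directly by a cuBLAS gemm (see AreValidGemmShapes
// above).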
bool IsMatrixMultiplication(const HloInstruction& dot) {
  if (dot.opcode() != HloOpcode::kDot) {
    return false;
  }
  const Shape& lhs_shape = dot.operand(0)->shape();
  const Shape& rhs_shape = dot.operand(1)->shape();
  const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();

  // If gemm can accept the operand shapes, use it rather than a custom
  // kernel.
  if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape(),
                         dim_numbers.lhs_batch_dimensions_size())) {
    // The size of the reduction dimension should match. The shape inference
    // guarantees this invariant, so the check here is for programming
    // errors.
    CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
             rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
    return true;
  }
  return false;
}

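// Returns true if `hlo` is a custom-call targeting the cuBLAS gemm routine.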
bool IsCublasGemm(const HloInstruction& hlo) {
  return hlo.opcode() == HloOpcode::kCustomCall &&
         hlo.custom_call_target() == kGemmCallTarget;
}

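// Returns the tile sizes to use along the three collapsed dimensions of the
// given reduction; row and column reductions are tiled differently.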
std::array<int64, 3> GetReductionTiling(
    const ReductionDimensions& reduction_dimensions) {
  if (reduction_dimensions.is_row_reduction) {
    int64 tile_z = std::min(reduction_dimensions.dimensions[0], int64{8});
    if (reduction_dimensions.dimensions[1] == 1) {
      CHECK_EQ(reduction_dimensions.dimensions[0], 1);
      return {tile_z, 1, 16};
    }
    if (reduction_dimensions.dimensions[2] % (kWarpSize * 64) == 0) {
      return {tile_z, 1, 64};
    }
    return {tile_z, 1, 8};
  }

  // Column reduction.
  return {1, 128, 1};
}

const char* const kCudnnBatchNormForwardInferenceCallTarget =
    "__cudnn$batchNormalizationForwardInference";
const char* const kCudnnBatchNormForwardTrainingCallTarget =
    "__cudnn$batchNormalizationForwardTraining";
const char* const kCudnnBatchNormBackwardCallTarget =
    "__cudnn$batchNormalizationBackward";

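// Returns true if `hlo` is a custom-call to one of the cuDNN batch
// normalization routines declared above.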
bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCudnnBatchNormForwardInferenceCallTarget ||
         target == kCudnnBatchNormForwardTrainingCallTarget ||
         target == kCudnnBatchNormBackwardCallTarget;
}

const char* const kGemmCallTarget = "__cublas$gemm";
const char* const kCudnnConvForwardCallTarget = "__cudnn$convForward";
const char* const kCudnnConvBackwardInputCallTarget =
    "__cudnn$convBackwardInput";
const char* const kCudnnConvBackwardFilterCallTarget =
    "__cudnn$convBackwardFilter";
const char* const kCudnnConvBiasActivationForwardCallTarget =
    "__cudnn$convBiasActivationForward";

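// Returns true if `hlo` is a custom-call to one of the cuDNN convolution
// routines declared above.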
bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCudnnConvForwardCallTarget ||
         target == kCudnnConvBackwardInputCallTarget ||
         target == kCudnnConvBackwardFilterCallTarget ||
         target == kCudnnConvBiasActivationForwardCallTarget;
}

const char* const kCusolverCholeskyCallTarget = "__cusolver$cholesky";

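// Returns true if `hlo` is a custom-call to the cuSOLVER Cholesky routine.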
bool IsCustomCallToCusolver(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCusolverCholeskyCallTarget;
}

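// Returns true if `hlo` is lowered to a library call (cuBLAS gemm, cuDNN
// batch norm, or cuDNN convolution) rather than to generated device code.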
bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
  return IsCublasGemm(hlo) || IsCustomCallToDnnBatchNorm(hlo) ||
         IsCustomCallToDnnConvolution(hlo);
}

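// Returns true if `reduce` is a kReduce whose kept or reduced dimensions are
// contiguous in the input layout and whose reduced extent is at least
// kWarpSize, so that the tiled reduction emitter can be used.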
bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) {
  if (HloOpcode::kReduce != reduce.opcode()) {
    return false;
  }

  // TODO(b/129698548): Remove this check after fixing the bug.
  if (reduce.shape().element_type() == C128) {
    return false;
  }

  const HloInstruction* input = reduce.operand(0);
  std::vector<int64> dims_to_keep;
  for (int64 dim = 0; dim < input->shape().rank(); ++dim) {
    if (!absl::c_linear_search(reduce.dimensions(), dim)) {
      dims_to_keep.push_back(dim);
    }
  }
  if (!LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
                                            dims_to_keep) &&
      !LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
                                            reduce.dimensions())) {
    return false;
  }

  ReductionDimensions reduction_dimensions =
      GetReductionKindAndContiguousComponents(reduce);

  if (reduction_dimensions.is_row_reduction) {
    // For row reduction, the tile block is 1 x tile_size_x, and we are
    // reducing along tile_size_x, which needs to be large enough to make the
    // tiling implementation efficient.
    return reduction_dimensions.dimensions[2] >= kWarpSize;
  }

  // For column reduction, the tile block is tile_size_y x tile_size_x, and we
  // are reducing along tile_size_y. Only tile_size_y needs to be
  // large enough to make the tiling implementation efficient.
  return reduction_dimensions.dimensions[1] >= kWarpSize;
}

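// Returns true if `unnested_hlo` is an input fusion whose root is a slice or
// a tuple of slices; if `verify_no_strides` is set, every slice must also be
// unit-strided.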
bool IsInputFusibleSlices(const HloInstruction& unnested_hlo,
                          bool verify_no_strides) {
  if (!unnested_hlo.IsInputFusion()) {
    return false;
  }

  auto is_non_strided = [](const std::vector<int64>& strides) -> bool {
    return absl::c_all_of(strides, [](int64 stride) { return stride == 1; });
  };

  const HloInstruction* root = unnested_hlo.fused_expression_root();
  if (root->opcode() == HloOpcode::kSlice) {
    return !verify_no_strides || is_non_strided(root->slice_strides());
  }

  if (root->opcode() != HloOpcode::kTuple) {
    return false;
  }

  return absl::c_all_of(root->operands(), [&](const HloInstruction* instr) {
    return instr->opcode() == HloOpcode::kSlice &&
           (!verify_no_strides || is_non_strided(instr->slice_strides()));
  });
}

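// Collapses the operand shape of `reduce` into three contiguous components
// and classifies the reduction as a row or a column reduction; see
// ReductionDimensions for the meaning of the returned values.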
ReductionDimensions GetReductionKindAndContiguousComponents(
    const HloInstruction& reduce) {
  const Shape& input_shape = reduce.operand(0)->shape();
  absl::Span<const int64> dims_to_reduce = reduce.dimensions();
  DimensionVector dims_to_keep;
  for (int64 dim = 0; dim < input_shape.rank(); ++dim) {
    if (!absl::c_linear_search(dims_to_reduce, dim)) {
      dims_to_keep.push_back(dim);
    }
  }

  if (dims_to_keep.empty()) {
    return {/*is_row_reduction=*/true,
            {1, 1, ShapeUtil::ElementsIn(input_shape)}};
  }

  if (LayoutUtil::AreDimensionsConsecutive(input_shape.layout(),
                                           dims_to_keep)) {
    std::array<int64, 3> shape_partition =
        PartitionShapeByMiddleDimensions(input_shape, dims_to_keep);
    if (shape_partition[1] == 1) {
      return {/*is_row_reduction=*/true,
              {1, 1, shape_partition[0] * shape_partition[2]}};
    }
    if (shape_partition[2] == 1) {
      return {/*is_row_reduction=*/false,
              {1, shape_partition[0], shape_partition[1]}};
    }
    return {/*is_row_reduction=*/true, shape_partition};
  }

  std::array<int64, 3> shape_partition =
      PartitionShapeByMiddleDimensions(input_shape, dims_to_reduce);

  if (shape_partition[2] == 1) {
    return {/*is_row_reduction=*/true,
            {1, shape_partition[0], shape_partition[1]}};
  }
  return {/*is_row_reduction=*/false, shape_partition};
}

// This emits a device-side call to
// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see
// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
llvm::Value* EmitPrintf(absl::string_view fmt,
                        absl::Span<llvm::Value* const> arguments,
                        llvm::IRBuilder<>* builder) {
  std::vector<llvm::Type*> argument_types;

  // Variadic-argument implicit promotion [1] converts float to double and
  // bool/char/short to int.
  // [1] https://en.cppreference.com/w/cpp/language/variadic_arguments
  auto requires_int32_promotion = [](llvm::Type* type) {
    return type->isIntegerTy(/*BitWidth=*/1) ||
           type->isIntegerTy(/*BitWidth=*/8) ||
           type->isIntegerTy(/*BitWidth=*/16);
  };
  auto requires_double_promotion = [](llvm::Type* type) {
    return type->isFloatingPointTy();
  };

  for (auto argument : arguments) {
    llvm::Type* type = argument->getType();
    if (requires_double_promotion(type)) {
      argument_types.push_back(builder->getDoubleTy());
    } else if (requires_int32_promotion(type)) {
      argument_types.push_back(builder->getInt32Ty());
    } else {
      argument_types.push_back(type);
    }
  }
  auto* arguments_type = llvm::StructType::create(argument_types);
  llvm::Value* arguments_ptr = builder->CreateAlloca(arguments_type);
  for (size_t i = 0; i < arguments.size(); ++i) {
    llvm::Value* value = arguments[i];
    llvm::Type* type = value->getType();
    if (requires_double_promotion(type)) {
      value = builder->CreateFPCast(value, builder->getDoubleTy());
    } else if (requires_int32_promotion(type)) {
      value = builder->CreateIntCast(value, builder->getInt32Ty(),
                                     /*isSigned=*/true);
    }
    builder->CreateStore(
        value, builder->CreateGEP(arguments_ptr, {builder->getInt64(0),
                                                  builder->getInt32(i)}));
  }
  llvm::Type* ptr_ty = builder->getInt8Ty()->getPointerTo();
  return builder->CreateCall(
      builder->GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
          "vprintf",
          llvm::FunctionType::get(builder->getInt32Ty(), {ptr_ty, ptr_ty},
                                  /*isVarArg=*/false)),
      {builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)),
       builder->CreatePointerCast(arguments_ptr, ptr_ty)});
}

// Helper function to emit a call to the AMDGPU shfl_down function.
llvm::Value* EmitAMDGPUShflDown(llvm::Value* value, llvm::Value* offset,
                                llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32);
  auto* i32_ty = b->getInt32Ty();
  llvm::FunctionCallee shfl_fn = module->getOrInsertFunction(
      llvm_ir::AsStringRef("__ockl_readuplane_i32"),
      llvm::FunctionType::get(/*Result=*/i32_ty, {i32_ty, i32_ty},
                              /*isVarArg=*/false));
  // The AMDGPU device function requires its first argument to be an i32.
  llvm::Value* result =
      b->CreateCall(shfl_fn, {b->CreateBitCast(value, i32_ty), offset});
  // The AMDGPU device function always returns an i32.
  return b->CreateBitCast(result, value->getType());
}

// Helper function to emit a call to the NVPTX shfl_down intrinsic.
llvm::Value* EmitNVPTXShflDown(llvm::Value* value, llvm::Value* offset,
                               llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  llvm::Intrinsic::ID llvm_intrinsic_id;
  CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32);
  if (value->getType()->isFloatTy()) {
    llvm_intrinsic_id = llvm::Intrinsic::nvvm_shfl_sync_down_f32;
  } else {
    llvm_intrinsic_id = llvm::Intrinsic::nvvm_shfl_sync_down_i32;
  }
  llvm::Function* intrinsic =
      llvm::Intrinsic::getDeclaration(module, llvm_intrinsic_id, {});
  return b->CreateCall(
      intrinsic, {b->getInt32(-1), value, offset, b->getInt32(kWarpSize - 1)});
}

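// Emits a full-warp shuffle-down of `value` by `offset`, dispatching to the
// NVPTX intrinsic or the AMDGPU device function and splitting values wider
// than 32 bits into 32-bit segments.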
llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
                                     llvm::IRBuilder<>* builder) {
  int bit_width = value->getType()->getPrimitiveSizeInBits();
  llvm::Module* module = builder->GetInsertBlock()->getModule();
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());

  // Special case for efficiency
  if (value->getType()->isFloatTy() && bit_width == 32) {
    if (target_triple.isNVPTX()) {
      return EmitNVPTXShflDown(value, offset, builder);
    } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
      return EmitAMDGPUShflDown(value, offset, builder);
    } else {
      LOG(FATAL) << "Invalid triple " << target_triple.str();
    }
  }

  // We must split values wider than 32 bits as the "shfl" instruction operates
  // on 32-bit values.
  int num_segments = CeilOfRatio(bit_width, 32);
  llvm::Value* x = builder->CreateBitCast(
      builder->CreateZExt(
          builder->CreateBitCast(value, builder->getIntNTy(bit_width)),
          builder->getIntNTy(32 * num_segments)),
      llvm::VectorType::get(builder->getInt32Ty(), num_segments));
  for (int i = 0; i < num_segments; ++i) {
    llvm::Value* insert_val;
    if (target_triple.isNVPTX()) {
      insert_val = EmitNVPTXShflDown(builder->CreateExtractElement(x, i),
                                     offset, builder);
    } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
      insert_val = EmitAMDGPUShflDown(builder->CreateExtractElement(x, i),
                                      offset, builder);
    } else {
      LOG(FATAL) << "Invalid triple " << target_triple.str();
    }
    x = builder->CreateInsertElement(x, insert_val, i);
  }
  return builder->CreateBitCast(
      builder->CreateTrunc(
          builder->CreateBitCast(x, builder->getIntNTy(32 * num_segments)),
          builder->getIntNTy(bit_width)),
      value->getType());
}

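// Maps the custom-call target of a cuDNN convolution instruction to the
// corresponding CudnnConvKind.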
StatusOr<CudnnConvKind> GetCudnnConvKind(
    const HloCustomCallInstruction* instr) {
  absl::string_view target = instr->custom_call_target();
  if (target == kCudnnConvForwardCallTarget) {
    return CudnnConvKind::kForward;
  }
  if (target == kCudnnConvBackwardInputCallTarget) {
    return CudnnConvKind::kBackwardInput;
  }
  if (target == kCudnnConvBackwardFilterCallTarget) {
    return CudnnConvKind::kBackwardFilter;
  }
  if (target == kCudnnConvBiasActivationForwardCallTarget) {
    return CudnnConvKind::kForwardActivation;
  }
  return InternalError("Unexpected call target: %s", target);
}

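// Maps the custom-call target of a cuDNN convolution instruction to the
// StreamExecutor se::dnn::ConvolutionKind.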
StatusOr<se::dnn::ConvolutionKind> GetDnnConvolutionKind(
    const HloCustomCallInstruction* instr) {
  absl::string_view target = instr->custom_call_target();
  if (target == kCudnnConvForwardCallTarget) {
    return se::dnn::ConvolutionKind::FORWARD;
  }
  if (target == kCudnnConvBackwardInputCallTarget) {
    return se::dnn::ConvolutionKind::BACKWARD_DATA;
  }
  if (target == kCudnnConvBackwardFilterCallTarget) {
    return se::dnn::ConvolutionKind::BACKWARD_FILTER;
  }
  return InternalError("Unexpected call target: %s", target);
}

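// Returns the se::dnn::DataType corresponding to the element type of the
// convolution's output (the first element of its result tuple).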
StatusOr<se::dnn::DataType> GetDnnDataType(
    const HloCustomCallInstruction* conv) {
  PrimitiveType output_primitive_type =
      conv->shape().tuple_shapes(0).element_type();
  switch (output_primitive_type) {
    case F16:
      return se::dnn::ToDataType<Eigen::half>::value;
    case F32:
      return se::dnn::ToDataType<float>::value;
    case F64:
      return se::dnn::ToDataType<double>::value;
    default:
      break;
  }
  return InternalError("Unsupported convolution datatype: %s",
                       conv->ToString());
}

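// Returns a human-readable name for the given convolution kind.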
string CudnnConvKindToString(CudnnConvKind kind) {
  switch (kind) {
    case CudnnConvKind::kForward:
      return "forward";
    case CudnnConvKind::kBackwardFilter:
      return "backward_filter";
    case CudnnConvKind::kBackwardInput:
      return "backward_input";
    case CudnnConvKind::kForwardActivation:
      return "forward with activation";
  }
}

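// Emits a predicate that is true exactly for thread 0 of block 0.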
llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) {
  return b->CreateAnd(
      b->CreateICmpEQ(
          b->getInt32(0),
          EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b)),
      b->CreateICmpEQ(
          b->getInt32(0),
          EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b)));
}

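// Checks that every reduce among `output_instructions` has the same shape,
// operand shapes, and reduced dimensions as `first_reduce`, and that every
// non-reduce output has a shape compatible with (ignoring element type) and
// laid out identically to `first_reduce`'s operand shape.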
bool AreFusedReductionOutputsConsistent(
    absl::Span<const HloInstruction* const> output_instructions,
    const HloInstruction* first_reduce) {
  for (const HloInstruction* inst : output_instructions) {
    if (IsReductionFromOrToContiguousDimensions(*inst)) {
      // Shapes, layouts and dimensions must be the same for all reduces
      // inside of this fusion.
      // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
      if (!(ShapeUtil::Equal(first_reduce->shape(), inst->shape()) &&
            ShapeUtil::Equal(first_reduce->operand(0)->shape(),
                             inst->operand(0)->shape()) &&
            ShapeUtil::Equal(first_reduce->operand(1)->shape(),
                             inst->operand(1)->shape()) &&
            first_reduce->dimensions() == inst->dimensions())) {
        return false;
      }
    } else {
      if (!(ShapeUtil::CompatibleIgnoringElementType(
                first_reduce->operand(0)->shape(), inst->shape()) &&
            LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
                              inst->shape().layout()))) {
        return false;
      }
    }
  }
  return true;
}

}  // namespace gpu
}  // namespace xla