/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"

#include <algorithm>
#include <array>
#include <vector>

#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

namespace xla {
namespace gpu {

namespace {

// Return whether the given shape is rank 2 excluding the batch dimensions.
bool IsRank2(const Shape& shape, int64 batch_dimensions_size) {
  return shape.rank() == batch_dimensions_size + 2;
}

// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
                        const Shape& output_shape,
                        int64 batch_dimensions_size) {
  // The inputs and the output must
  // 1) be matrices with no padding and a non-zero number of elements,
  // 2) have an allowed element type.
  PrimitiveType output_primitive_type = output_shape.element_type();
  bool type_is_allowed =
      (output_primitive_type == F16 || output_primitive_type == F32 ||
       output_primitive_type == F64 || output_primitive_type == C64 ||
       output_primitive_type == C128);
  return type_is_allowed && IsRank2(lhs_shape, batch_dimensions_size) &&
         IsRank2(rhs_shape, batch_dimensions_size) &&
         IsRank2(output_shape, batch_dimensions_size) &&
         !ShapeUtil::IsZeroElementArray(lhs_shape) &&
         !ShapeUtil::IsZeroElementArray(rhs_shape);
}

// Given a shape and a group of contiguous dimensions in the shape, returns
// a tuple of three values (major, middle, minor), where major is the size of
// the dimensions more major than the given dimensions, minor is the size of
// dimensions more minor than the given dimensions, and middle is the size of
// the given dimensions.
std::array<int64, 3> PartitionShapeByMiddleDimensions(
    const Shape& shape, absl::Span<const int64> dims_middle) {
  CHECK(LayoutUtil::AreDimensionsConsecutive(shape.layout(), dims_middle));
  std::array<int64, 3> values = {1, 1, 1};
  enum Segment { kMajor = 0, kMiddle = 1, kMinor = 2 };
  Segment cur_segment = kMinor;

  for (int64 cur_dim : LayoutUtil::MinorToMajor(shape)) {
    if (cur_segment != kMajor) {
      // Handle change of segments.
      bool cur_dim_in_middle = absl::c_linear_search(dims_middle, cur_dim);
      if (cur_segment == kMinor) {
        if (cur_dim_in_middle) {
          cur_segment = kMiddle;
        }
      } else if (cur_segment == kMiddle) {
        if (!cur_dim_in_middle) {
          cur_segment = kMajor;
        }
      }
    }
    values[cur_segment] *= shape.dimensions(cur_dim);
  }
  return values;
}

}  // namespace

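// Returns true if `dot` is a kDot instruction whose shapes can be handled by a
// cuBLAS gemm call (see AreValidGemmShapes above) rather than a custom kernel.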
bool IsMatrixMultiplication(const HloInstruction& dot) {
  if (dot.opcode() != HloOpcode::kDot) {
    return false;
  }
  const Shape& lhs_shape = dot.operand(0)->shape();
  const Shape& rhs_shape = dot.operand(1)->shape();
  const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();

  // If gemm can accept the operand shapes, use it rather than a custom
  // kernel.
  if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape(),
                         dim_numbers.lhs_batch_dimensions_size())) {
    // The size of the reduction dimension should match. The shape inference
    // guarantees this invariant, so the check here is for programming
    // errors.
    CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
             rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
    return true;
  }
  return false;
}

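// Returns true if `hlo` is a custom call with the cuBLAS gemm call target.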
bool IsCublasGemm(const HloInstruction& hlo) {
  return hlo.opcode() == HloOpcode::kCustomCall &&
         hlo.custom_call_target() == kGemmCallTarget;
}

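// Picks the tile sizes used for a reduction. Row reductions use a wider tile
// along the minor (reduced) dimension when it is a large multiple of the warp
// size; column reductions always use a 1 x 128 x 1 tiling.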
std::array<int64, 3> GetReductionTiling(
    const ReductionDimensions& reduction_dimensions) {
  if (reduction_dimensions.is_row_reduction) {
    int64 tile_z = std::min(reduction_dimensions.dimensions[0], int64{8});
    if (reduction_dimensions.dimensions[1] == 1) {
      CHECK_EQ(reduction_dimensions.dimensions[0], 1);
      return {tile_z, 1, 16};
    }
    if (reduction_dimensions.dimensions[2] % (kWarpSize * 64) == 0) {
      return {tile_z, 1, 64};
    }
    return {tile_z, 1, 8};
  }

  // Column reduction.
  return {1, 128, 1};
}

const char* const kCudnnBatchNormForwardInferenceCallTarget =
    "__cudnn$batchNormalizationForwardInference";
const char* const kCudnnBatchNormForwardTrainingCallTarget =
    "__cudnn$batchNormalizationForwardTraining";
const char* const kCudnnBatchNormBackwardCallTarget =
    "__cudnn$batchNormalizationBackward";

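// Returns true if `hlo` is a custom call to one of the cuDNN batch
// normalization routines named above.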
bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCudnnBatchNormForwardInferenceCallTarget ||
         target == kCudnnBatchNormForwardTrainingCallTarget ||
         target == kCudnnBatchNormBackwardCallTarget;
}

const char* const kGemmCallTarget = "__cublas$gemm";
const char* const kCudnnConvForwardCallTarget = "__cudnn$convForward";
const char* const kCudnnConvBackwardInputCallTarget =
    "__cudnn$convBackwardInput";
const char* const kCudnnConvBackwardFilterCallTarget =
    "__cudnn$convBackwardFilter";
const char* const kCudnnConvBiasActivationForwardCallTarget =
    "__cudnn$convBiasActivationForward";

bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCudnnConvForwardCallTarget ||
         target == kCudnnConvBackwardInputCallTarget ||
         target == kCudnnConvBackwardFilterCallTarget ||
         target == kCudnnConvBiasActivationForwardCallTarget;
}

const char* const kCusolverCholeskyCallTarget = "__cusolver$cholesky";

bool IsCustomCallToCusolver(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCusolverCholeskyCallTarget;
}

bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
  return IsCublasGemm(hlo) || IsCustomCallToDnnBatchNorm(hlo) ||
         IsCustomCallToDnnConvolution(hlo);
}

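// Returns true if `reduce` is a kReduce over dimensions that are contiguous in
// the input layout (or whose kept dimensions are contiguous), and the reduced
// extent is large enough (at least a warp) to make the tiled reduction
// emitter worthwhile.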
bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) {
  if (HloOpcode::kReduce != reduce.opcode()) {
    return false;
  }

  // TODO(b/129698548): Remove this check after fixing the bug.
  if (reduce.shape().element_type() == C128) {
    return false;
  }

  const HloInstruction* input = reduce.operand(0);
  std::vector<int64> dims_to_keep;
  for (int64 dim = 0; dim < input->shape().dimensions().size(); ++dim) {
    if (!absl::c_linear_search(reduce.dimensions(), dim)) {
      dims_to_keep.push_back(dim);
    }
  }
  if (!LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
                                            dims_to_keep) &&
      !LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
                                            reduce.dimensions())) {
    return false;
  }

  ReductionDimensions reduction_dimensions =
      GetReductionKindAndContiguousComponents(reduce);

  if (reduction_dimensions.is_row_reduction) {
    // For row reduction, the tile block is 1 x tile_size_x, and we are reducing
    // along tile_size_x which needs to be large enough to make the tiling
    // implementation efficient.
    return reduction_dimensions.dimensions[2] >= kWarpSize;
  }

  // For column reduction, the tile block is tile_size_y x tile_size_x, and we
  // are reducing along tile_size_y. Only tile_size_y needs to be
  // large enough to make the tiling implementation efficient.
  return reduction_dimensions.dimensions[1] >= kWarpSize;
}

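// Returns true if `unnested_hlo` is an input fusion whose root is a slice (or
// a tuple of slices). When `verify_no_strides` is set, every slice must also
// be unit-strided.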
bool IsInputFusibleSlices(const HloInstruction& unnested_hlo,
                          bool verify_no_strides) {
  if (!unnested_hlo.IsInputFusion()) {
    return false;
  }

  auto is_non_strided = [](const std::vector<int64>& strides) -> bool {
    return absl::c_all_of(strides, [](int64 stride) { return stride == 1; });
  };

  const HloInstruction* root = unnested_hlo.fused_expression_root();
  if (root->opcode() == HloOpcode::kSlice) {
    return !verify_no_strides || is_non_strided(root->slice_strides());
  }

  if (root->opcode() != HloOpcode::kTuple) {
    return false;
  }

  return absl::c_all_of(root->operands(), [&](const HloInstruction* instr) {
    return instr->opcode() == HloOpcode::kSlice &&
           (!verify_no_strides || is_non_strided(instr->slice_strides()));
  });
}

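// Flattens the reduction into at most three contiguous components
// (major, middle, minor) and classifies it as a row reduction (reducing the
// minor-most component) or a column reduction (reducing the middle component).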
ReductionDimensions GetReductionKindAndContiguousComponents(
    const HloInstruction& reduce) {
  const Shape& input_shape = reduce.operand(0)->shape();
  absl::Span<const int64> dims_to_reduce = reduce.dimensions();
  DimensionVector dims_to_keep;
  for (int64 dim = 0; dim < input_shape.rank(); ++dim) {
    if (!absl::c_linear_search(dims_to_reduce, dim)) {
      dims_to_keep.push_back(dim);
    }
  }

  if (dims_to_keep.empty()) {
    return {/*is_row_reduction=*/true,
            {1, 1, ShapeUtil::ElementsIn(input_shape)}};
  }

  if (LayoutUtil::AreDimensionsConsecutive(input_shape.layout(),
                                           dims_to_keep)) {
    std::array<int64, 3> shape_partition =
        PartitionShapeByMiddleDimensions(input_shape, dims_to_keep);
    if (shape_partition[1] == 1) {
      return {/*is_row_reduction=*/true,
              {1, 1, shape_partition[0] * shape_partition[2]}};
    }
    if (shape_partition[2] == 1) {
      return {/*is_row_reduction=*/false,
              {1, shape_partition[0], shape_partition[1]}};
    }
    return {/*is_row_reduction=*/true, shape_partition};
  }

  std::array<int64, 3> shape_partition =
      PartitionShapeByMiddleDimensions(input_shape, dims_to_reduce);

  if (shape_partition[2] == 1) {
    return {/*is_row_reduction=*/true,
            {1, shape_partition[0], shape_partition[1]}};
  }
  return {/*is_row_reduction=*/false, shape_partition};
}

// This emits a device-side call to
// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see
// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
llvm::Value* EmitPrintf(absl::string_view fmt,
                        absl::Span<llvm::Value* const> arguments,
                        llvm::IRBuilder<>* builder) {
  std::vector<llvm::Type*> argument_types;

  // Variadic arguments implicit promotion [1] converts float to double,
  // and bool/char/short are converted to int.
  // [1] https://en.cppreference.com/w/cpp/language/variadic_arguments
  auto requires_int32_promotion = [](llvm::Type* type) {
    return type->isIntegerTy(/*BitWidth=*/1) ||
           type->isIntegerTy(/*BitWidth=*/8) ||
           type->isIntegerTy(/*BitWidth=*/16);
  };
  auto requires_double_promotion = [](llvm::Type* type) {
    return type->isFloatingPointTy();
  };

  for (auto argument : arguments) {
    llvm::Type* type = argument->getType();
    if (requires_double_promotion(type)) {
      argument_types.push_back(builder->getDoubleTy());
    } else if (requires_int32_promotion(type)) {
      argument_types.push_back(builder->getInt32Ty());
    } else {
      argument_types.push_back(type);
    }
  }
  auto* arguments_type = llvm::StructType::create(argument_types);
  llvm::Value* arguments_ptr = builder->CreateAlloca(arguments_type);
  for (size_t i = 0; i < arguments.size(); ++i) {
    llvm::Value* value = arguments[i];
    llvm::Type* type = value->getType();
    if (requires_double_promotion(type)) {
      value = builder->CreateFPCast(value, builder->getDoubleTy());
    } else if (requires_int32_promotion(type)) {
      value = builder->CreateIntCast(value, builder->getInt32Ty(),
                                     /*isSigned=*/true);
    }
    builder->CreateStore(
        value, builder->CreateGEP(arguments_ptr, {builder->getInt64(0),
                                                  builder->getInt32(i)}));
  }
  llvm::Type* ptr_ty = builder->getInt8Ty()->getPointerTo();
  return builder->CreateCall(
      builder->GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
          "vprintf",
          llvm::FunctionType::get(builder->getInt32Ty(), {ptr_ty, ptr_ty},
                                  /*isVarArg=*/false)),
      {builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)),
       builder->CreatePointerCast(arguments_ptr, ptr_ty)});
}

// Helper function to emit call to AMDGPU shfl_down function.
llvm::Value* EmitAMDGPUShflDown(llvm::Value* value, llvm::Value* offset,
                                llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32);
  auto* i32_ty = b->getInt32Ty();
  llvm::FunctionCallee shfl_fn = module->getOrInsertFunction(
      llvm_ir::AsStringRef("__ockl_readuplane_i32"),
      llvm::FunctionType::get(/*Result=*/i32_ty, {i32_ty, i32_ty},
                              /*isVarArg=*/false));
  // AMDGPU device function requires first argument as i32.
  llvm::Value* result =
      b->CreateCall(shfl_fn, {b->CreateBitCast(value, i32_ty), offset});
  // AMDGPU device function always returns an i32 type.
  return b->CreateBitCast(result, value->getType());
}

// Helper function to emit call to NVPTX shfl_down intrinsic.
llvm::Value* EmitNVPTXShflDown(llvm::Value* value, llvm::Value* offset,
                               llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  llvm::Intrinsic::ID llvm_intrinsic_id;
  CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32);
  if (value->getType()->isFloatTy()) {
    llvm_intrinsic_id = llvm::Intrinsic::nvvm_shfl_sync_down_f32;
  } else {
    llvm_intrinsic_id = llvm::Intrinsic::nvvm_shfl_sync_down_i32;
  }
  llvm::Function* intrinsic =
      llvm::Intrinsic::getDeclaration(module, llvm_intrinsic_id, {});
  return b->CreateCall(
      intrinsic, {b->getInt32(-1), value, offset, b->getInt32(kWarpSize - 1)});
}

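// Emits a full-warp shuffle-down of `value` by `offset` lanes. Values wider
// than 32 bits are split into 32-bit segments, shuffled segment by segment,
// and reassembled.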
llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
                                     llvm::IRBuilder<>* builder) {
  int bit_width = value->getType()->getPrimitiveSizeInBits();
  llvm::Module* module = builder->GetInsertBlock()->getModule();
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());

  // Special case for efficiency
  if (value->getType()->isFloatTy() && bit_width == 32) {
    if (target_triple.isNVPTX()) {
      return EmitNVPTXShflDown(value, offset, builder);
    } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
      return EmitAMDGPUShflDown(value, offset, builder);
    } else {
      LOG(FATAL) << "Invalid triple " << target_triple.str();
    }
  }

  // We must split values wider than 32 bits as the "shfl" instruction operates
  // on 32-bit values.
  int num_segments = CeilOfRatio(bit_width, 32);
  llvm::Value* x = builder->CreateBitCast(
      builder->CreateZExt(
          builder->CreateBitCast(value, builder->getIntNTy(bit_width)),
          builder->getIntNTy(32 * num_segments)),
      llvm::VectorType::get(builder->getInt32Ty(), num_segments));
  for (int i = 0; i < num_segments; ++i) {
    llvm::Value* insert_val;
    if (target_triple.isNVPTX()) {
      insert_val = EmitNVPTXShflDown(builder->CreateExtractElement(x, i),
                                     offset, builder);
    } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
      insert_val = EmitAMDGPUShflDown(builder->CreateExtractElement(x, i),
                                      offset, builder);
    } else {
      LOG(FATAL) << "Invalid triple " << target_triple.str();
    }
    x = builder->CreateInsertElement(x, insert_val, i);
  }
  return builder->CreateBitCast(
      builder->CreateTrunc(
          builder->CreateBitCast(x, builder->getIntNTy(32 * num_segments)),
          builder->getIntNTy(bit_width)),
      value->getType());
}

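// Maps a cuDNN convolution custom-call target to the corresponding
// CudnnConvKind, or returns an InternalError for unknown targets.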
StatusOr<CudnnConvKind> GetCudnnConvKind(
    const HloCustomCallInstruction* instr) {
  absl::string_view target = instr->custom_call_target();
  if (target == kCudnnConvForwardCallTarget) {
    return CudnnConvKind::kForward;
  }
  if (target == kCudnnConvBackwardInputCallTarget) {
    return CudnnConvKind::kBackwardInput;
  }
  if (target == kCudnnConvBackwardFilterCallTarget) {
    return CudnnConvKind::kBackwardFilter;
  }
  if (target == kCudnnConvBiasActivationForwardCallTarget) {
    return CudnnConvKind::kForwardActivation;
  }
  return InternalError("Unexpected call target: %s", target);
}

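// Maps a cuDNN convolution custom-call target to the StreamExecutor
// se::dnn::ConvolutionKind. Bias-activation convolutions are not handled here.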
StatusOr<se::dnn::ConvolutionKind> GetDnnConvolutionKind(
    const HloCustomCallInstruction* instr) {
  absl::string_view target = instr->custom_call_target();
  if (target == kCudnnConvForwardCallTarget) {
    return se::dnn::ConvolutionKind::FORWARD;
  }
  if (target == kCudnnConvBackwardInputCallTarget) {
    return se::dnn::ConvolutionKind::BACKWARD_DATA;
  }
  if (target == kCudnnConvBackwardFilterCallTarget) {
    return se::dnn::ConvolutionKind::BACKWARD_FILTER;
  }
  return InternalError("Unexpected call target: %s", target);
}

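// Returns the se::dnn::DataType matching the element type of the
// convolution's result (the first element of its output tuple).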
StatusOr<se::dnn::DataType> GetDnnDataType(
    const HloCustomCallInstruction* conv) {
  PrimitiveType output_primitive_type =
      conv->shape().tuple_shapes(0).element_type();
  switch (output_primitive_type) {
    case F16:
      return se::dnn::ToDataType<Eigen::half>::value;
    case F32:
      return se::dnn::ToDataType<float>::value;
    case F64:
      return se::dnn::ToDataType<double>::value;
    default:
      break;
  }
  return InternalError("Unsupported convolution datatype : %s",
                       conv->ToString());
}

string CudnnConvKindToString(CudnnConvKind kind) {
  switch (kind) {
    case CudnnConvKind::kForward:
      return "forward";
    case CudnnConvKind::kBackwardFilter:
      return "backward_filter";
    case CudnnConvKind::kBackwardInput:
      return "backward_input";
    case CudnnConvKind::kForwardActivation:
      return "forward with activation";
  }
}

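// Emits an i1 value that is true iff the current thread is thread 0 of
// block 0.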
llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) {
  return b->CreateAnd(
      b->CreateICmpEQ(
          b->getInt32(0),
          EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b)),
      b->CreateICmpEQ(
          b->getInt32(0),
          EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b)));
}

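// Checks that all outputs of a reduction fusion are consistent with the first
// reduce: reduces must have identical shapes, operand shapes, and dimensions,
// and non-reduce outputs must match the reduce input shape (ignoring the
// element type) and layout.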
bool AreFusedReductionOutputsConsistent(
    absl::Span<const HloInstruction* const> output_instructions,
    const HloInstruction* first_reduce) {
  for (const HloInstruction* inst : output_instructions) {
    if (IsReductionFromOrToContiguousDimensions(*inst)) {
      // Shapes, layouts and dimensions must be the same for all reduces
      // inside of this fusion.
      // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
      if (!(ShapeUtil::Equal(first_reduce->shape(), inst->shape()) &&
            ShapeUtil::Equal(first_reduce->operand(0)->shape(),
                             inst->operand(0)->shape()) &&
            ShapeUtil::Equal(first_reduce->operand(1)->shape(),
                             inst->operand(1)->shape()) &&
            first_reduce->dimensions() == inst->dimensions())) {
        return false;
      }
    } else {
      if (!(ShapeUtil::CompatibleIgnoringElementType(
                first_reduce->operand(0)->shape(), inst->shape()) &&
            LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
                              inst->shape().layout()))) {
        return false;
      }
    }
  }
  return true;
}

}  // namespace gpu
}  // namespace xla