/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"

#include <algorithm>
#include <array>
#include <vector>

#include "llvm/IR/IntrinsicsNVPTX.h"
#include "llvm/IR/Module.h"
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/hlo_utils.h"
#include "tensorflow/compiler/mlir/xla/mlir_hlo_to_hlo.h"
#include "tensorflow/compiler/mlir/xla/type_to_shape.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/target_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/stream_executor/device_description.h"

namespace xla {
namespace gpu {

namespace {

// Return whether the given shape is rank 2 excluding the batch dimensions.
bool IsRank2(const Shape& shape, int64_t batch_dimensions_size) {
  return shape.rank() == batch_dimensions_size + 2;
}

// Given a shape and a group of contiguous dimensions in the shape, returns
// a tuple of three values (major, middle, minor), where major is the size of
// the dimensions more major than the given dimensions, minor is the size of
// the dimensions more minor than the given dimensions, and middle is the size
// of the given dimensions.
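//
// For example (illustrative): for a row-major shape f32[2,3,4,5] and
// dims_middle = {1, 2}, the minor-to-major walk visits dimensions 3, 2, 1, 0,
// so the result is {2, 12, 5} (major = 2, middle = 3 * 4, minor = 5).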
std::array<int64, 3> PartitionShapeByMiddleDimensions(
    const Shape& shape, absl::Span<const int64> dims_middle) {
  CHECK(LayoutUtil::AreDimensionsConsecutive(shape.layout(), dims_middle));
  std::array<int64, 3> values = {1, 1, 1};
  enum Segment { kMajor = 0, kMiddle = 1, kMinor = 2 };
  Segment cur_segment = kMinor;

  for (int64_t cur_dim : LayoutUtil::MinorToMajor(shape)) {
    if (cur_segment != kMajor) {
      // Handle change of segments.
      bool cur_dim_in_middle = absl::c_linear_search(dims_middle, cur_dim);
      if (cur_segment == kMinor) {
        if (cur_dim_in_middle) {
          cur_segment = kMiddle;
        }
      } else if (cur_segment == kMiddle) {
        if (!cur_dim_in_middle) {
          cur_segment = kMajor;
        }
      }
    }
    values[cur_segment] *= shape.dimensions(cur_dim);
  }
  return values;
}

Shape GetShapeFromTensorType(mlir::Value value) {
  constexpr char kDefaultLayoutAttrName[] = "minor_to_major";

  mlir::Operation* op = value.getDefiningOp();
  CHECK(op);
  CHECK(value.getType().isa<mlir::TensorType>());
  Shape shape = TypeToShape(value.getType());
  if (auto attr = op->getAttrOfType<mlir::DenseIntElementsAttr>(
          kDefaultLayoutAttrName)) {
    std::vector<int64> minor_to_major;
    absl::c_transform(
        attr, std::back_inserter(minor_to_major),
        std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
    *shape.mutable_layout() = LayoutUtil::MakeLayout(minor_to_major);
  } else {
    *shape.mutable_layout() = LayoutUtil::MakeDescendingLayout(
        value.getType().cast<mlir::ShapedType>().getShape().size());
  }
  return shape;
}

}  // namespace

bool IsMatrixMultiplication(const HloInstruction& dot) {
  if (dot.opcode() != HloOpcode::kDot) {
    return false;
  }
  const Shape& lhs_shape = dot.operand(0)->shape();
  const Shape& rhs_shape = dot.operand(1)->shape();
  const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();

  PrimitiveType output_primitive_type = dot.shape().element_type();
  bool type_is_allowed =
      (output_primitive_type == F16 || output_primitive_type == BF16 ||
       output_primitive_type == F32 || output_primitive_type == F64 ||
       output_primitive_type == C64 || output_primitive_type == C128) ||
      (output_primitive_type == S32 && lhs_shape.element_type() == S8 &&
       rhs_shape.element_type() == S8);
  bool shapes_are_valid =
      type_is_allowed &&
      IsRank2(lhs_shape, dim_numbers.lhs_batch_dimensions_size()) &&
      IsRank2(rhs_shape, dim_numbers.lhs_batch_dimensions_size()) &&
      IsRank2(dot.shape(), dim_numbers.lhs_batch_dimensions_size()) &&
      !ShapeUtil::IsZeroElementArray(lhs_shape) &&
      !ShapeUtil::IsZeroElementArray(rhs_shape);

  if (!shapes_are_valid) {
    return false;
  }

  // The sizes of the lhs and rhs contracting dimensions should match. Shape
  // inference guarantees this invariant, so the check here is for programming
  // errors.
  CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
           rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));

  return true;
}

bool IsCublasGemm(const HloInstruction& hlo) {
  return hlo.opcode() == HloOpcode::kCustomCall &&
         hlo.custom_call_target() == kGemmCallTarget;
}

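// Returns the tile sizes used by the reduction emitter, indexed in the same
// order as ReductionDimensions::dimensions. The specific values below are
// performance heuristics rather than correctness requirements.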
std::array<int64, 3> GetReductionTiling(
    const ReductionDimensions& reduction_dimensions,
    int smallest_input_dtype_bits,
    se::CudaComputeCapability cuda_compute_capability) {
  if (reduction_dimensions.is_row_reduction) {
    int64_t tile_z = std::min(reduction_dimensions.dimensions[0],
                              kBatchedReductionRaceFreeBound);
    return {tile_z, 1, 64};
  }

  // Column reduction.
  return {1, 128, 1};
}

const char* const kCudnnBatchNormForwardInferenceCallTarget =
    "__cudnn$batchNormalizationForwardInference";
const char* const kCudnnBatchNormForwardTrainingCallTarget =
    "__cudnn$batchNormalizationForwardTraining";
const char* const kCudnnBatchNormBackwardCallTarget =
    "__cudnn$batchNormalizationBackward";

bool IsCustomCallToDnnBatchNorm(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCudnnBatchNormForwardInferenceCallTarget ||
         target == kCudnnBatchNormForwardTrainingCallTarget ||
         target == kCudnnBatchNormBackwardCallTarget;
}

const char* const kGemmCallTarget = "__cublas$gemm";
const char* const kCudnnConvForwardCallTarget = "__cudnn$convForward";
const char* const kCudnnConvBackwardInputCallTarget =
    "__cudnn$convBackwardInput";
const char* const kCudnnConvBackwardFilterCallTarget =
    "__cudnn$convBackwardFilter";
const char* const kCudnnConvBiasActivationForwardCallTarget =
    "__cudnn$convBiasActivationForward";

bool IsCustomCallToDnnConvolution(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCudnnConvForwardCallTarget ||
         target == kCudnnConvBackwardInputCallTarget ||
         target == kCudnnConvBackwardFilterCallTarget ||
         target == kCudnnConvBiasActivationForwardCallTarget;
}

const char* const kCusolverCholeskyCallTarget = "__cusolver$cholesky";

bool IsCustomCallToCusolver(const HloInstruction& hlo) {
  if (hlo.opcode() != HloOpcode::kCustomCall) {
    return false;
  }
  const auto& target = hlo.custom_call_target();
  return target == kCusolverCholeskyCallTarget;
}

bool ImplementedAsLibraryCall(const HloInstruction& hlo) {
  return IsCublasGemm(hlo) || IsCustomCallToDnnBatchNorm(hlo) ||
         IsCustomCallToDnnConvolution(hlo);
}

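// Classifies the reduction of `input_shape` along `dims_to_reduce` as a row or
// column reduction and collapses the shape into (at most) three contiguous
// components. For example (illustrative): reducing a row-major f32[100,200]
// along dimension 1 yields a row reduction with dimensions {1, 100, 200}.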
static ReductionDimensions GetReductionKindAndContiguousComponentsImpl(
    const Shape& input_shape, absl::Span<const int64> dims_to_reduce) {
  DimensionVector dims_to_keep;
  for (int64_t dim = 0; dim < input_shape.rank(); ++dim) {
    if (!absl::c_linear_search(dims_to_reduce, dim)) {
      dims_to_keep.push_back(dim);
    }
  }

  if (dims_to_keep.empty()) {
    return {/*is_row_reduction=*/true,
            {1, 1, ShapeUtil::ElementsIn(input_shape)}};
  }

  if (LayoutUtil::AreDimensionsConsecutive(input_shape.layout(),
                                           dims_to_keep)) {
    std::array<int64, 3> shape_partition =
        PartitionShapeByMiddleDimensions(input_shape, dims_to_keep);
    if (shape_partition[1] == 1) {
      return {/*is_row_reduction=*/true,
              {1, 1, shape_partition[0] * shape_partition[2]}};
    }
    if (shape_partition[2] == 1) {
      return {/*is_row_reduction=*/false,
              {1, shape_partition[0], shape_partition[1]}};
    }
    return {/*is_row_reduction=*/true, shape_partition};
  }

  std::array<int64, 3> shape_partition =
      PartitionShapeByMiddleDimensions(input_shape, dims_to_reduce);

  if (shape_partition[2] == 1) {
    return {/*is_row_reduction=*/true,
            {1, shape_partition[0], shape_partition[1]}};
  }
  return {/*is_row_reduction=*/false, shape_partition};
}

static bool IsUnnestedReductionFasterThanElemental(
    const ReductionDimensions& reduction_dimensions) {
  if (reduction_dimensions.is_row_reduction) {
    // For row reduction, the tile block is 1 x tile_size_x, and we are reducing
    // along tile_size_x which needs to be large enough to make the tiling
    // implementation efficient.
    return reduction_dimensions.dimensions[2] >= kWarpSize;
  }

  // For column reduction, the tile block is tile_size_y x tile_size_x, and we
  // are reducing along tile_size_y. Only tile_size_y needs to be
  // large enough to make the tiling implementation efficient.
  return reduction_dimensions.dimensions[1] >= kWarpSize;
}

bool IsReductionFromOrToContiguousDimensions(const HloInstruction& reduce) {
  if (HloOpcode::kReduce != reduce.opcode()) {
    return false;
  }

  // TODO(b/129698548): Remove this check after fixing the bug.
  if (reduce.shape().element_type() == C128) {
    return false;
  }

  const HloInstruction* input = reduce.operand(0);
  std::vector<int64> dims_to_keep;
  for (int64_t dim = 0; dim < input->shape().dimensions().size(); ++dim) {
    if (!absl::c_linear_search(reduce.dimensions(), dim)) {
      dims_to_keep.push_back(dim);
    }
  }

  // We support fast codegen for three cases:
  // 1) Row reduction: (K, R)
  // 2) Column reduction: (K, R, K)
  // 3) "Batched" row reduction: (R, K, R)
  if (!LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
                                            dims_to_keep) &&
      !LayoutUtil::AreDimensionsConsecutive(input->shape().layout(),
                                            reduce.dimensions())) {
    return false;
  }

  return IsUnnestedReductionFasterThanElemental(
      GetReductionKindAndContiguousComponents(reduce));
}

// Constructs the fusion layout analysis object by using a heuristic to infer
// the layout of a fusion internal value. In general, if the value is derived
// from a fusion parameter (which by definition has a layout) using elementwise
// operations, it will inherit the layout of that parameter. On the other hand,
// if the value is written to a fusion output, it will inherit the layout of
// that output. If the heuristic fails, the default layout will be inferred.
FusionLayoutAnalysis::FusionLayoutAnalysis(mlir::lmhlo::FusionOp fusion_op) {
  VLOG(3) << "Analyzing \n" << MlirToString(fusion_op);
  auto add_layout = [this](mlir::Value v, const Layout& layout) {
    layouts_[v] = layout;
    VLOG(3) << "===============\n";
    VLOG(3) << "For value \n" << MlirToString(v.getDefiningOp());
    VLOG(3) << "Layout = " << layout.ToString() << "\n";
    VLOG(3) << "===============\n";
  };

  // Propagate layouts inside the fusion region.
  for (mlir::Operation& op : fusion_op.region().front().without_terminator()) {
    if (auto load = mlir::dyn_cast<mlir::memref::TensorLoadOp>(op)) {
      add_layout(load, GetShape(load.memref()).layout());
    } else if (auto store = mlir::dyn_cast<mlir::memref::TensorStoreOp>(op)) {
      // Propagate the stored memref layout to the value if it does not have an
      // inferred layout already. This prefers coalescing loads over stores.
      if (layouts_.count(store.tensor()) == 0) {
        add_layout(store.tensor(), GetShape(store.memref()).layout());
      }
    } else if (auto bitcast = mlir::dyn_cast<mlir::mhlo::BitcastOp>(op)) {
      auto attr =
          bitcast->getAttrOfType<mlir::DenseIntElementsAttr>("result_layout");
      std::vector<int64> minor_to_major;
      absl::c_transform(
          attr, std::back_inserter(minor_to_major),
          std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
      add_layout(bitcast, LayoutUtil::MakeLayout(minor_to_major));

      attr =
          bitcast->getAttrOfType<mlir::DenseIntElementsAttr>("source_layout");
      minor_to_major.clear();
      absl::c_transform(
          attr, std::back_inserter(minor_to_major),
          std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
      add_layout(bitcast.operand(), LayoutUtil::MakeLayout(minor_to_major));
    } else {
      HloOpcode opcode = *xla::MhloToHloOpcode(&op);
      if (!HloInstruction::IsOpElementwise(opcode)) {
        continue;
      }
      // If any operand has a layout, infer that layout for the result of the
      // operation. If two operands have conflicting layouts, we still need to
      // choose one of them, so we arbitrarily choose the first one.
      for (mlir::Value operand : op.getOperands()) {
        auto it = layouts_.find(operand);
        if (it != layouts_.end()) {
          // Do not pass in a reference to an entry in the map when adding a new
          // entry. The map may expand when adding, and the reference may become
          // invalid. To avoid this, create a local copy of the layout.
          const Layout operand_layout = it->second;
          add_layout(op.getResult(0), operand_layout);
          break;
        }
      }
    }
  }
}

Shape FusionLayoutAnalysis::GetShape(mlir::Value value) const {
  Shape shape = TypeToShape(value.getType());
  if (!value.getType().isa<mlir::MemRefType>()) {
    auto it = layouts_.find(value);
    if (it != layouts_.end()) {
      *shape.mutable_layout() = it->second;
    }
  }
  return shape;
}

bool IsReductionFromOrToContiguousDimensions(
    mlir::Operation* reduce, const FusionLayoutAnalysis& layout_analysis) {
  if (!mlir::isa<mlir::mhlo::ReduceOp>(reduce)) {
    return false;
  }
  std::vector<mlir::Value> results = GetHloOutputs(reduce);
  CHECK_EQ(1, results.size());

  auto c128_type =
      mlir::ComplexType::get(mlir::FloatType::getF64(reduce->getContext()));

  // TODO(b/129698548): Remove this check after fixing the bug.
  if (results[0].getType().cast<mlir::ShapedType>().getElementType() ==
      c128_type) {
    return false;
  }

  mlir::Value input = reduce->getOperand(0);
  const Shape operand_shape = layout_analysis.GetShape(input);

  // Enable this code to check mismatch between the inferred layout and what was
  // there before. Based on actual runs, some mismatches are expected.
#if 0
  Shape operand_shape_ir = GetShape(input);
  if (auto tensor_type = input.getType().dyn_cast<mlir::TensorType>()) {
    if (auto attr = mlir::GetLayoutFromMlirHlo(input.getDefiningOp())) {
      std::vector<int64> minor_to_major;
      absl::c_transform(
          attr, std::back_inserter(minor_to_major),
          std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
      *operand_shape_ir.mutable_layout() =
          LayoutUtil::MakeLayout(minor_to_major);
    }
  }
  bool match = ShapeUtil::Equal(operand_shape, operand_shape_ir);
  llvm::errs() << "inferred shape = " << operand_shape.ToString(true) << "\n";
  llvm::errs() << "Actual shape in IR = " << operand_shape_ir.ToString(true)
               << "\n";
  if (!match) {
    llvm::errs() << "Unable to infer layout for reduce op operand(0)\n";
    llvm::errs() << "\nreduce = \n";
    reduce->dump();
    llvm::errs() << "\nparent = \n";
    reduce->getParentOp()->dump();
    CHECK(0);
  }
#endif

  std::vector<int64> dimensions;
  {
    auto attr = reduce->getAttrOfType<mlir::DenseIntElementsAttr>("dimensions");
    CHECK(attr);
    absl::c_transform(
        attr, std::back_inserter(dimensions),
        std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
  }

  std::vector<int64> dims_to_keep;
  for (int64_t dim = 0; dim < operand_shape.dimensions().size(); ++dim) {
    if (!absl::c_linear_search(dimensions, dim)) {
      dims_to_keep.push_back(dim);
    }
  }

  // We support fast codegen for three cases:
  // 1) Row reduction: (K, R)
  // 2) Column reduction: (K, R, K)
  // 3) "Batched" row reduction: (R, K, R)
  if (!LayoutUtil::AreDimensionsConsecutive(operand_shape.layout(),
                                            dims_to_keep) &&
      !LayoutUtil::AreDimensionsConsecutive(operand_shape.layout(),
                                            dimensions)) {
    return false;
  }

  return IsUnnestedReductionFasterThanElemental(
      GetReductionKindAndContiguousComponentsImpl(operand_shape, dimensions));
}

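// Returns true if `unnested_hlo` is a fusion whose results are all defined by
// mhlo::SliceOp ops (and, if `verify_no_strides` is set, the slices have unit
// strides).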
bool IsInputFusibleSlices(mlir::Operation* unnested_hlo,
                          bool verify_no_strides) {
  auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(unnested_hlo);
  if (!fusion) {
    return false;
  }

  auto is_non_strided = [](mlir::DenseIntElementsAttr strides) -> bool {
    return absl::c_all_of(
        strides, [](const llvm::APInt& stride) { return stride == 1; });
  };

  for (mlir::Value value : fusion.getFusionResults()) {
    auto slice =
        mlir::dyn_cast_or_null<mlir::mhlo::SliceOp>(value.getDefiningOp());
    if (!slice) {
      return false;
    }
    if (verify_no_strides && !is_non_strided(slice.strides())) {
      return false;
    }
  }
  return true;
}

ReductionDimensions GetReductionKindAndContiguousComponents(
    const HloInstruction& reduce) {
  return GetReductionKindAndContiguousComponentsImpl(reduce.operand(0)->shape(),
                                                     reduce.dimensions());
}

ReductionDimensions GetReductionKindAndContiguousComponents(
    mlir::Operation* reduce) {
  mlir::Value input = reduce->getOperand(0);
  Shape operand_shape = GetShape(input);
  std::vector<int64> dimensions;
  {
    auto attr = reduce->getAttrOfType<mlir::DenseIntElementsAttr>("dimensions");
    CHECK(attr);
    absl::c_transform(
        attr, std::back_inserter(dimensions),
        std::function<int64(const llvm::APInt&)>(&llvm::APInt::getZExtValue));
  }
  return GetReductionKindAndContiguousComponentsImpl(operand_shape, dimensions);
}

// This emits a device-side call to
// "i32 vprintf(i8* fmt, arguments_type* arguments)" in the driver; see
// http://docs.nvidia.com/cuda/ptx-writers-guide-to-interoperability/index.html#system-calls
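//
// Example use (illustrative): EmitPrintf("x = %f\n", {x_value}, &b);
// Note that "%f" consumes a double; float arguments are promoted below.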
llvm::Value* EmitPrintf(absl::string_view fmt,
                        absl::Span<llvm::Value* const> arguments,
                        llvm::IRBuilder<>* builder) {
  std::vector<llvm::Type*> argument_types;

  // Variadic arguments implicit promotion [1] converts float to double,
  // and bool/char/short are converted to int.
  // [1] https://en.cppreference.com/w/cpp/language/variadic_arguments
  auto requires_int32_promotion = [](llvm::Type* type) {
    return type->isIntegerTy(/*BitWidth=*/1) ||
           type->isIntegerTy(/*BitWidth=*/8) ||
           type->isIntegerTy(/*BitWidth=*/16);
  };
  auto requires_double_promotion = [](llvm::Type* type) {
    return type->isFloatingPointTy();
  };

  for (auto argument : arguments) {
    llvm::Type* type = argument->getType();
    if (requires_double_promotion(type)) {
      argument_types.push_back(builder->getDoubleTy());
    } else if (requires_int32_promotion(type)) {
      argument_types.push_back(builder->getInt32Ty());
    } else {
      argument_types.push_back(type);
    }
  }
  auto* arguments_type = llvm::StructType::create(argument_types);
  llvm::Value* arguments_ptr = builder->CreateAlloca(arguments_type);
  for (size_t i = 0; i < arguments.size(); ++i) {
    llvm::Value* value = arguments[i];
    llvm::Type* type = value->getType();
    if (requires_double_promotion(type)) {
      value = builder->CreateFPCast(value, builder->getDoubleTy());
    } else if (requires_int32_promotion(type)) {
      value = builder->CreateIntCast(value, builder->getInt32Ty(),
                                     /*isSigned=*/true);
    }
    builder->CreateStore(
        value, builder->CreateGEP(arguments_ptr, {builder->getInt64(0),
                                                  builder->getInt32(i)}));
  }
  llvm::Type* ptr_ty = builder->getInt8Ty()->getPointerTo();
  return builder->CreateCall(
      builder->GetInsertBlock()->getParent()->getParent()->getOrInsertFunction(
          "vprintf",
          llvm::FunctionType::get(builder->getInt32Ty(), {ptr_ty, ptr_ty},
                                  /*isVarArg=*/false)),
      {builder->CreateGlobalStringPtr(llvm_ir::AsStringRef(fmt)),
       builder->CreatePointerCast(arguments_ptr, ptr_ty)});
}

// Helper function to emit a call to the AMDGPU shfl_down function.
llvm::Value* EmitAMDGPUShflDown(llvm::Value* value, llvm::Value* offset,
                                llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32);
  auto* i32_ty = b->getInt32Ty();
  llvm::FunctionCallee shfl_fn = module->getOrInsertFunction(
      llvm_ir::AsStringRef("__ockl_readuplane_i32"),
      llvm::FunctionType::get(/*Result=*/i32_ty, {i32_ty, i32_ty},
                              /*isVarArg=*/false));
  // The AMDGPU device function requires the first argument to be i32.
  llvm::Value* result =
      b->CreateCall(shfl_fn, {b->CreateBitCast(value, i32_ty), offset});
  // The AMDGPU device function always returns an i32 type.
  return b->CreateBitCast(result, value->getType());
}

// Helper function to emit a call to the NVPTX shfl_down intrinsic.
llvm::Value* EmitNVPTXShflDown(llvm::Value* value, llvm::Value* offset,
                               llvm::IRBuilder<>* b) {
  llvm::Module* module = b->GetInsertBlock()->getModule();
  llvm::Intrinsic::ID llvm_intrinsic_id;
  CHECK_EQ(value->getType()->getPrimitiveSizeInBits(), 32);
  if (value->getType()->isFloatTy()) {
    llvm_intrinsic_id = llvm::Intrinsic::nvvm_shfl_sync_down_f32;
  } else {
    llvm_intrinsic_id = llvm::Intrinsic::nvvm_shfl_sync_down_i32;
  }
  llvm::Function* intrinsic =
      llvm::Intrinsic::getDeclaration(module, llvm_intrinsic_id, {});
  return b->CreateCall(
      intrinsic, {b->getInt32(-1), value, offset, b->getInt32(kWarpSize - 1)});
}

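// Emits a full-warp shuffle-down of `value` by `offset` lanes. Values wider
// than 32 bits are split into 32-bit segments, since the underlying shuffle
// operates on 32-bit registers. A typical (illustrative) use is a warp-level
// reduction that repeatedly halves the offset:
//   for (int o = kWarpSize / 2; o > 0; o /= 2)
//     partial = <combine>(partial,
//                         EmitFullWarpShuffleDown(partial, b->getInt32(o), b));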
llvm::Value* EmitFullWarpShuffleDown(llvm::Value* value, llvm::Value* offset,
                                     llvm::IRBuilder<>* builder) {
  int bit_width = value->getType()->getPrimitiveSizeInBits();
  llvm::Module* module = builder->GetInsertBlock()->getModule();
  llvm::Triple target_triple = llvm::Triple(module->getTargetTriple());

  // Special case for efficiency.
  if (value->getType()->isFloatTy() && bit_width == 32) {
    if (target_triple.isNVPTX()) {
      return EmitNVPTXShflDown(value, offset, builder);
    } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
      return EmitAMDGPUShflDown(value, offset, builder);
    } else {
      LOG(FATAL) << "Invalid triple " << target_triple.str();
    }
  }

  // We must split values wider than 32 bits as the "shfl" instruction operates
  // on 32-bit values.
  int num_segments = CeilOfRatio(bit_width, 32);
  llvm::Value* x = builder->CreateBitCast(
      builder->CreateZExt(
          builder->CreateBitCast(value, builder->getIntNTy(bit_width)),
          builder->getIntNTy(32 * num_segments)),
      llvm::VectorType::get(builder->getInt32Ty(), num_segments, false));
  for (int i = 0; i < num_segments; ++i) {
    llvm::Value* insert_val;
    if (target_triple.isNVPTX()) {
      insert_val = EmitNVPTXShflDown(builder->CreateExtractElement(x, i),
                                     offset, builder);
    } else if (target_triple.getArch() == llvm::Triple::amdgcn) {
      insert_val = EmitAMDGPUShflDown(builder->CreateExtractElement(x, i),
                                      offset, builder);
    } else {
      LOG(FATAL) << "Invalid triple " << target_triple.str();
    }
    x = builder->CreateInsertElement(x, insert_val, i);
  }
  return builder->CreateBitCast(
      builder->CreateTrunc(
          builder->CreateBitCast(x, builder->getIntNTy(32 * num_segments)),
          builder->getIntNTy(bit_width)),
      value->getType());
}

StatusOr<CudnnConvKind> GetCudnnConvKind(
    const HloCustomCallInstruction* instr) {
  absl::string_view target = instr->custom_call_target();
  if (target == kCudnnConvForwardCallTarget) {
    return CudnnConvKind::kForward;
  }
  if (target == kCudnnConvBackwardInputCallTarget) {
    return CudnnConvKind::kBackwardInput;
  }
  if (target == kCudnnConvBackwardFilterCallTarget) {
    return CudnnConvKind::kBackwardFilter;
  }
  if (target == kCudnnConvBiasActivationForwardCallTarget) {
    return CudnnConvKind::kForwardActivation;
  }
  return InternalError("Unexpected call target: %s", target);
}

string CudnnConvKindToString(CudnnConvKind kind) {
  switch (kind) {
    case CudnnConvKind::kForward:
      return "forward";
    case CudnnConvKind::kBackwardFilter:
      return "backward_filter";
    case CudnnConvKind::kBackwardInput:
      return "backward_input";
    case CudnnConvKind::kForwardActivation:
      return "forward with activation";
  }
}

llvm::Value* IsBlock0Thread0(llvm::IRBuilder<>* b) {
  llvm::Value* is_thread0 = b->CreateICmpEQ(
      b->getInt32(0),
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kThreadIdx, {}, {}, b));

  llvm::Value* is_block0 = b->CreateICmpEQ(
      b->getInt32(0),
      EmitCallToTargetIntrinsic(TargetIntrinsicID::kBlockIdx, {}, {}, b));
  return b->CreateAnd(is_thread0, is_block0);
}

bool IsFusedReductionOutputConsistent(const HloInstruction* inst,
                                      const HloInstruction* first_reduce) {
  if (IsReductionFromOrToContiguousDimensions(*inst)) {
    // Shapes, layouts and dimensions must be the same for all reduces
    // inside of this fusion.
    // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
    return ShapeUtil::Equal(first_reduce->shape(), inst->shape()) &&
           ShapeUtil::Equal(first_reduce->operand(0)->shape(),
                            inst->operand(0)->shape()) &&
           ShapeUtil::Equal(first_reduce->operand(1)->shape(),
                            inst->operand(1)->shape()) &&
           first_reduce->dimensions() == inst->dimensions();
  }
  return ShapeUtil::CompatibleIgnoringElementType(
             first_reduce->operand(0)->shape(), inst->shape()) &&
         LayoutUtil::Equal(first_reduce->operand(0)->shape().layout(),
                           inst->shape().layout());
}

bool IsFusedReductionOutputConsistent(
    mlir::mhlo::ReduceOp inst, mlir::mhlo::ReduceOp first_reduce,
    const FusionLayoutAnalysis& layout_analysis) {
  CHECK_EQ(1, first_reduce.getNumResults());
  Shape first_reduce_operand_shape =
      layout_analysis.GetShape(first_reduce.inputs()[0]);
  CHECK_EQ(1, inst.getNumResults());
  Shape inst_shape = layout_analysis.GetShape(inst.getResult(0));

  if (IsReductionFromOrToContiguousDimensions(inst, layout_analysis)) {
    Shape first_reduce_shape =
        layout_analysis.GetShape(first_reduce.getResult(0));
    Shape first_reduce_init_shape =
        layout_analysis.GetShape(first_reduce.init_values()[0]);

    Shape inst_operand_shape = layout_analysis.GetShape(inst.inputs()[0]);
    Shape inst_init_shape = layout_analysis.GetShape(inst.init_values()[0]);

    // Shapes, layouts and dimensions must be the same for all reduces
    // inside of this fusion.
    // TODO(tjoerg): Relax the shape constraint. The datatype does not matter.
    if (!(ShapeUtil::Equal(first_reduce_shape, inst_shape) &&
          ShapeUtil::Equal(first_reduce_operand_shape, inst_operand_shape) &&
          ShapeUtil::Equal(first_reduce_init_shape, inst_init_shape) &&
          absl::c_equal(first_reduce.dimensions(), inst.dimensions()))) {
      return false;
    }
  } else {
    if (!(ShapeUtil::CompatibleIgnoringElementType(first_reduce_operand_shape,
                                                   inst_shape) &&
          LayoutUtil::Equal(first_reduce_operand_shape.layout(),
                            inst_shape.layout()))) {
      return false;
    }
  }
  return true;
}

// Given an LMHLO op, returns the operand index of the first output operand.
//
// Notice that an operand aliased to an output isn't an output, even though in
// that case WritesMlirBuffer() returns true on that operand.
//
// An operand either fails WritesMlirBuffer() or equals (aliases) a later
// operand. An output is the opposite: WritesMlirBuffer() holds for it and it
// does not equal any later operand.
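//
// For example (illustrative), for an LMHLO op with operands
// (input0, input1, output0), where only output0 is written and not aliased,
// this returns 2.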
int PartitionLmhloOperandsAndOutputs(mlir::Operation* op) {
  CHECK(op->getDialect() == op->getContext()->getLoadedDialect("lmhlo"));

  int i;
  for (i = op->getOperands().size() - 1; i >= 0; i--) {
    const bool aliased =
        std::find(op->getOperands().begin() + i + 1, op->getOperands().end(),
                  op->getOperand(i)) != op->getOperands().end();
    if (!WritesMlirBuffer(op, op->getOperand(i)) || aliased) {
      break;
    }
  }
  return i + 1;
}

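// Returns the input (read) operands of `op`: for lmhlo ops, the buffer
// operands before the first output; for mhlo ops, simply the op's operands.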
std::vector<mlir::Value> GetHloOperands(mlir::Operation* op) {
  if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
    return ToStdVector(fusion.getInputBuffers());
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("lmhlo")) {
    int output_start = PartitionLmhloOperandsAndOutputs(op);
    std::vector<mlir::Value> operands;
    operands.reserve(output_start);
    for (int i = 0; i < output_start; i++) {
      operands.push_back(op->getOperand(i));
    }
    return operands;
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("mhlo")) {
    return std::vector<mlir::Value>(op->getOperands().begin(),
                                    op->getOperands().end());
  }
  LOG(FATAL) << "Unexpected op: " << MlirToString(op);
}

std::vector<mlir::Value> GetHloOutputs(mlir::Operation* op) {
  if (auto fusion = mlir::dyn_cast<mlir::lmhlo::FusionOp>(op)) {
    return ToStdVector(fusion.getOutputBuffers());
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("lmhlo")) {
    int output_start = PartitionLmhloOperandsAndOutputs(op);
    std::vector<mlir::Value> outputs;
    for (int i = output_start; i < op->getNumOperands(); i++) {
      outputs.push_back(op->getOperand(i));
    }
    return outputs;
  }
  if (op->getDialect() == op->getContext()->getLoadedDialect("mhlo")) {
    return std::vector<mlir::Value>(op->getResults().begin(),
                                    op->getResults().end());
  }
  LOG(FATAL) << "Unexpected op: " << MlirToString(op);
}

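// Returns true if `op` writes to the buffer `operand`, as reported by the MLIR
// memory-effect interface.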
bool WritesMlirBuffer(mlir::Operation* op, mlir::Value operand) {
  llvm::SmallVector<mlir::MemoryEffects::EffectInstance, 2> effects;
  mlir::cast<mlir::MemoryEffectOpInterface>(op).getEffectsOnValue(operand,
                                                                  effects);
  return absl::c_any_of(
      effects, [](const mlir::MemoryEffects::EffectInstance& instance) {
        return mlir::isa<mlir::MemoryEffects::Write>(instance.getEffect());
      });
}

static int64_t GetMemRefSizeInBytes(mlir::MemRefType type) {
  // For i1 memrefs, the underlying allocation is 8 bits.
  if (type.getElementType().isInteger(/*width=*/1)) {
    return type.getNumElements();
  } else {
    return type.getSizeInBits() / CHAR_BIT;
  }
}

static int64_t GetAllocationIndex(mlir::BlockArgument func_arg,
                                  std::string* constant_name) {
  auto func_op =
      mlir::cast<mlir::FuncOp>(func_arg.getParentRegion()->getParentOp());
  if (constant_name) {
    if (auto constant_name_attr = func_op.getArgAttrOfType<mlir::StringAttr>(
            func_arg.getArgNumber(), "lmhlo.constant_name")) {
      *constant_name = constant_name_attr.getValue().str();
    }
  }
  return func_arg.getArgNumber();
}

StatusOr<BufferAllocation::Slice> GetAllocationSlice(
    mlir::Value v, absl::Span<const BufferAllocation> allocations,
    std::string* constant_name) {
  if (constant_name) {
    constant_name->clear();
  }

  int64_t size = GetMemRefSizeInBytes(v.getType().cast<mlir::MemRefType>());

  // We match the following patterns here:
  //  base := ViewOp(arg) | get_global_memref (global_memref) | arg
  //  root := base | MemRefReinterpretCastOp(base)

  if (auto cast = mlir::dyn_cast_or_null<mlir::memref::ReinterpretCastOp>(
          v.getDefiningOp())) {
    v = cast.getViewSource();
  }
  if (auto view =
          mlir::dyn_cast_or_null<mlir::memref::ViewOp>(v.getDefiningOp())) {
    TF_RET_CHECK(view.source().isa<mlir::BlockArgument>());

    return BufferAllocation::Slice(
        &allocations[GetAllocationIndex(
            view.source().cast<mlir::BlockArgument>(), constant_name)],
        mlir::cast<mlir::ConstantOp>(view.byte_shift().getDefiningOp())
            .value()
            .cast<mlir::IntegerAttr>()
            .getValue()
            .getSExtValue(),
        size);
  }
  if (auto get_global = mlir::dyn_cast_or_null<mlir::memref::GetGlobalOp>(
          v.getDefiningOp())) {
    auto module = get_global->getParentOfType<mlir::ModuleOp>();
    if (constant_name) {
      *constant_name = get_global.name().str();
    }
    auto global = mlir::cast<mlir::memref::GlobalOp>(
        module.lookupSymbol(get_global.name()));
    int64_t index =
        global->getAttrOfType<mlir::IntegerAttr>("lmhlo.alloc").getInt();
    return BufferAllocation::Slice(&allocations[index], 0,
                                   allocations[index].size());
  }
  if (auto arg = v.dyn_cast<mlir::BlockArgument>()) {
    return BufferAllocation::Slice(
        &allocations[GetAllocationIndex(arg, constant_name)], 0, size);
  }

  return Unimplemented(
      "Operand has to be in the form of ViewOp(arg) or "
      "StaticMemRefCastOp(ViewOp(arg)) or arg");
}

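// Returns true if the fusion's single result is a dynamic-update-slice whose
// updated operand lives in the same allocation slice as the fusion output, so
// the update can be emitted in place.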
bool CanEmitFusedDynamicUpdateSliceInPlaceForGpu(
    mlir::lmhlo::FusionOp fusion,
    absl::Span<const BufferAllocation> allocations) {
  auto results = fusion.getFusionResults();
  if (results.size() != 1) {
    return false;
  }
  auto dus = mlir::dyn_cast<mlir::mhlo::DynamicUpdateSliceOp>(
      results[0].getDefiningOp());
  if (!dus) {
    return false;
  }

  auto output_buffers = fusion.getOutputBuffers();
  CHECK_EQ(1, output_buffers.size());
  auto parameter =
      mlir::dyn_cast<mlir::memref::TensorLoadOp>(dus.operand().getDefiningOp());

  if (!parameter) {
    return false;
  }

  auto maybe_lhs = GetAllocationSlice(parameter.memref(), allocations);
  auto maybe_rhs = GetAllocationSlice(output_buffers[0], allocations);
  return maybe_lhs.ok() && maybe_rhs.ok() && *maybe_lhs == *maybe_rhs;
}

Shape GetShape(mlir::Value value) {
  if (value.getType().isa<mlir::MemRefType>()) {
    return TypeToShape(value.getType());
  } else if (value.getType().isa<mlir::TensorType>()) {
    return GetShapeFromTensorType(value);
  } else if (value.getType().isa<mlir::TupleType>()) {
    return TypeToShape(value.getType());
  }
  LOG(FATAL) << "Unexpected value type to get shape for";
  return {};
}

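// Returns true if the reduction is small enough, relative to the given tiling,
// to be emitted race free (i.e. each output element can be produced without
// atomic accumulation across thread blocks).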
bool ReductionIsRaceFree(const ReductionDimensions& reduction_dimensions,
                         const std::array<int64_t, 3>& reduction_tiling) {
  return (reduction_dimensions.is_row_reduction &&
          reduction_dimensions.dimensions[2] <=
              kMinThreadsXRowReduction * reduction_tiling[2] &&
          reduction_dimensions.dimensions[0] <=
              kBatchedReductionRaceFreeBound) ||
         (!reduction_dimensions.is_row_reduction &&
          reduction_dimensions.dimensions[1] <=
              kWarpSize * reduction_tiling[1]);
}

}  // namespace gpu
}  // namespace xla