1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for lowering LHLO dialect to GPU dialect.
17
18 #include <cstdint>
19
20 #include "llvm/ADT/ArrayRef.h"
21 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
22 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
23 #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
24 #include "mlir/Dialect/Affine/IR/AffineOps.h"
25 #include "mlir/Dialect/GPU/GPUDialect.h"
26 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
27 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
28 #include "mlir/Dialect/MemRef/IR/MemRef.h"
29 #include "mlir/Dialect/SCF/SCF.h"
30 #include "mlir/Dialect/StandardOps/IR/Ops.h"
31 #include "mlir/IR/Attributes.h"
32 #include "mlir/IR/BlockAndValueMapping.h"
33 #include "mlir/IR/Builders.h"
34 #include "mlir/IR/BuiltinOps.h"
35 #include "mlir/IR/BuiltinTypes.h"
36 #include "mlir/IR/Location.h"
37 #include "mlir/IR/MLIRContext.h"
38 #include "mlir/IR/Operation.h"
39 #include "mlir/IR/PatternMatch.h"
40 #include "mlir/Pass/Pass.h"
41 #include "mlir/Transforms/DialectConversion.h"
42
43 namespace mlir {
44 namespace lmhlo {
45 namespace {
46
47 // A simple translation of LHLO reduce operations to a corresponding gpu
48 // launch operation. The transformation does no tiling and also only supports
49 // 1d results.
50 class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
51 public:
52 using OpConversionPattern::OpConversionPattern;
53
matchAndRewrite(ReduceOp reduce_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const54 LogicalResult matchAndRewrite(
55 ReduceOp reduce_op, ArrayRef<Value> args,
56 ConversionPatternRewriter& rewriter) const final {
57 auto loc = reduce_op.getLoc();
58 // Only support 1d reductions for now.
59 int64_t size = 0;
60 for (auto result : reduce_op.out()) {
61 auto shaped_type = result.getType().dyn_cast<ShapedType>();
62 if (!shaped_type || shaped_type.getRank() != 1) {
63 return failure();
64 }
65 auto dim_size = shaped_type.getDimSize(0);
66 if (size && size != dim_size) {
67 return failure();
68 }
69 size = dim_size;
70 }
71
72 auto reducing_dimension = *reduce_op.dimensions().int_value_begin();
73
74 // Require all inputs to have the same shape.
75 int64_t reduce_dim_size = 0;
76 for (auto input : reduce_op.inputs()) {
77 auto shaped_type = input.getType().dyn_cast<ShapedType>();
78 if (!shaped_type || !shaped_type.hasStaticShape()) {
79 return failure();
80 }
81 reduce_dim_size =
82 shaped_type.getDimSize(reducing_dimension.getSExtValue());
83 }
84
85 // Create a launch that is parallel in the result dimension.
86 auto block_size_x = rewriter.create<mlir::ConstantOp>(
87 loc, rewriter.getIndexType(),
88 rewriter.getIntegerAttr(rewriter.getIndexType(), size));
89 auto one = rewriter.create<mlir::ConstantOp>(
90 loc, rewriter.getIndexType(),
91 rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
92 auto launch_op = rewriter.create<mlir::gpu::LaunchOp>(
93 loc, one, one, one, block_size_x, one, one);
94 {
95 OpBuilder::InsertionGuard guard(rewriter);
96 rewriter.setInsertionPointToEnd(&launch_op.body().front());
97 auto index = launch_op.getThreadIds().x;
98
99 // Load the initial value and store it to the output.
100 for (auto pair : llvm::zip(reduce_op.init_values(), reduce_op.out())) {
101 auto init_value =
102 rewriter.create<mlir::memref::LoadOp>(loc, std::get<0>(pair));
103 rewriter.create<mlir::memref::StoreOp>(
104 loc, init_value, std::get<1>(pair), ArrayRef<Value>{index});
105 }
106
107 // Insert a loop into the body to compute the reduction. The loop ranges
108 // from [0.dim).
109 auto zero = rewriter.create<mlir::ConstantOp>(
110 loc, rewriter.getIndexType(),
111 rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
112 // TODO(b/137624192) Use dimOp to make it shape independent.
113 auto upper = rewriter.create<mlir::ConstantOp>(
114 loc, rewriter.getIndexType(),
115 rewriter.getIntegerAttr(rewriter.getIndexType(), reduce_dim_size));
116 auto step = rewriter.create<mlir::ConstantOp>(
117 loc, rewriter.getIndexType(),
118 rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
119 auto loop = rewriter.create<mlir::scf::ForOp>(loc, zero, upper, step);
120
121 rewriter.setInsertionPointToStart(loop.getBody());
122 // Compute memrefs for the value to reduce. This makes it easier to just
123 // inline the body.
124 auto output = *reduce_op.out().begin();
125 auto resType = MemRefType::get(
126 llvm::None, getElementTypeOrSelf(output.getType()),
127 makeStridedLinearLayoutMap(llvm::None,
128 MemRefType::getDynamicStrideOrOffset(),
129 rewriter.getContext()));
130 OpFoldResult offset = launch_op.getThreadIds().x;
131 auto oneAttr = rewriter.getI64IntegerAttr(1);
132 OpFoldResult size = oneAttr;
133 OpFoldResult stride = oneAttr;
134 auto accumulator = rewriter.create<memref::SubViewOp>(
135 loc, resType, output, offset, size, stride);
136 llvm::SmallVector<Value, 4> indexings;
137 Value input_buffer = reduce_op.inputs().front();
138 auto input_type_rank =
139 input_buffer.getType().cast<MemRefType>().getRank();
140
141 Value input = *reduce_op.operand_begin();
142 SmallVector<OpFoldResult> offsets = llvm::to_vector<4>(llvm::map_range(
143 llvm::seq<int>(0, input_type_rank), [&](int dim) -> OpFoldResult {
144 return dim == reducing_dimension ? loop.getInductionVar()
145 : launch_op.getThreadIds().x;
146 }));
147 SmallVector<OpFoldResult> sizes(input_type_rank, oneAttr);
148 SmallVector<OpFoldResult> strides(input_type_rank, oneAttr);
149 auto rhs = rewriter.create<memref::SubViewOp>(
150 loc, accumulator.getType(), input, offsets, sizes, strides);
151
152 // Now copy over the actual body of the reduction, leaving out the
153 // terminator.
154 BlockAndValueMapping mapping;
155 mapping.map(reduce_op.body().getArgument(0), accumulator);
156 mapping.map(reduce_op.body().getArgument(1), rhs);
157 mapping.map(reduce_op.body().getArgument(2), accumulator);
158 for (auto& nested : reduce_op.body().front().without_terminator()) {
159 auto clone = rewriter.clone(nested, mapping);
160 for (auto pair : llvm::zip(nested.getResults(), clone->getResults())) {
161 mapping.map(std::get<0>(pair), std::get<1>(pair));
162 }
163 }
164
165 // Finally, insert the terminator for the launchOp.
166 rewriter.setInsertionPointToEnd(&launch_op.body().front());
167 rewriter.create<mlir::gpu::TerminatorOp>(loc);
168 }
169
170 rewriter.eraseOp(reduce_op);
171 return success();
172 };
173 };
174
175 struct LhloLegalizeToGpuPass
176 : public LhloLegalizeToGpuPassBase<LhloLegalizeToGpuPass> {
getDependentDialectsmlir::lmhlo::__anon972849c40111::LhloLegalizeToGpuPass177 void getDependentDialects(DialectRegistry& registry) const override {
178 registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
179 memref::MemRefDialect, scf::SCFDialect>();
180 }
181
runOnFunctionmlir::lmhlo::__anon972849c40111::LhloLegalizeToGpuPass182 void runOnFunction() override {
183 OwningRewritePatternList patterns(&getContext());
184 ConversionTarget target(getContext());
185 target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
186 StandardOpsDialect, gpu::GPUDialect, scf::SCFDialect,
187 LmhloDialect>();
188 target.addIllegalOp<ReduceOp>();
189 auto func = getFunction();
190 patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
191 if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
192 signalPassFailure();
193 }
194 }
195 };
196
197 } // namespace
198
createLegalizeToGpuPass()199 std::unique_ptr<FunctionPass> createLegalizeToGpuPass() {
200 return std::make_unique<LhloLegalizeToGpuPass>();
201 }
202
203 } // namespace lmhlo
204 } // namespace mlir
205