• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for lowering LHLO dialect to GPU dialect.
17 
18 #include <cstdint>
19 
20 #include "llvm/ADT/ArrayRef.h"
21 #include "mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h"
22 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
23 #include "mlir-hlo/Dialect/mhlo/transforms/map_lmhlo_to_scalar_op.h"
24 #include "mlir/Dialect/Affine/IR/AffineOps.h"
25 #include "mlir/Dialect/GPU/GPUDialect.h"
26 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
27 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
28 #include "mlir/Dialect/MemRef/IR/MemRef.h"
29 #include "mlir/Dialect/SCF/SCF.h"
30 #include "mlir/Dialect/StandardOps/IR/Ops.h"
31 #include "mlir/IR/Attributes.h"
32 #include "mlir/IR/BlockAndValueMapping.h"
33 #include "mlir/IR/Builders.h"
34 #include "mlir/IR/BuiltinOps.h"
35 #include "mlir/IR/BuiltinTypes.h"
36 #include "mlir/IR/Location.h"
37 #include "mlir/IR/MLIRContext.h"
38 #include "mlir/IR/Operation.h"
39 #include "mlir/IR/PatternMatch.h"
40 #include "mlir/Pass/Pass.h"
41 #include "mlir/Transforms/DialectConversion.h"
42 
43 namespace mlir {
44 namespace lmhlo {
45 namespace {
46 
47 // A simple translation of LHLO reduce operations to a corresponding gpu
48 // launch operation. The transformation does no tiling and also only supports
49 // 1d results.
50 class LhloReduceToGPULaunchConverter : public OpConversionPattern<ReduceOp> {
51  public:
52   using OpConversionPattern::OpConversionPattern;
53 
matchAndRewrite(ReduceOp reduce_op,ArrayRef<Value> args,ConversionPatternRewriter & rewriter) const54   LogicalResult matchAndRewrite(
55       ReduceOp reduce_op, ArrayRef<Value> args,
56       ConversionPatternRewriter& rewriter) const final {
57     auto loc = reduce_op.getLoc();
58     // Only support 1d reductions for now.
59     int64_t size = 0;
60     for (auto result : reduce_op.out()) {
61       auto shaped_type = result.getType().dyn_cast<ShapedType>();
62       if (!shaped_type || shaped_type.getRank() != 1) {
63         return failure();
64       }
65       auto dim_size = shaped_type.getDimSize(0);
66       if (size && size != dim_size) {
67         return failure();
68       }
69       size = dim_size;
70     }
71 
72     auto reducing_dimension = *reduce_op.dimensions().int_value_begin();
73 
74     // Require all inputs to have the same shape.
75     int64_t reduce_dim_size = 0;
76     for (auto input : reduce_op.inputs()) {
77       auto shaped_type = input.getType().dyn_cast<ShapedType>();
78       if (!shaped_type || !shaped_type.hasStaticShape()) {
79         return failure();
80       }
81       reduce_dim_size =
82           shaped_type.getDimSize(reducing_dimension.getSExtValue());
83     }
84 
85     // Create a launch that is parallel in the result dimension.
86     auto block_size_x = rewriter.create<mlir::ConstantOp>(
87         loc, rewriter.getIndexType(),
88         rewriter.getIntegerAttr(rewriter.getIndexType(), size));
89     auto one = rewriter.create<mlir::ConstantOp>(
90         loc, rewriter.getIndexType(),
91         rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
92     auto launch_op = rewriter.create<mlir::gpu::LaunchOp>(
93         loc, one, one, one, block_size_x, one, one);
94     {
95       OpBuilder::InsertionGuard guard(rewriter);
96       rewriter.setInsertionPointToEnd(&launch_op.body().front());
97       auto index = launch_op.getThreadIds().x;
98 
99       // Load the initial value and store it to the output.
100       for (auto pair : llvm::zip(reduce_op.init_values(), reduce_op.out())) {
101         auto init_value =
102             rewriter.create<mlir::memref::LoadOp>(loc, std::get<0>(pair));
103         rewriter.create<mlir::memref::StoreOp>(
104             loc, init_value, std::get<1>(pair), ArrayRef<Value>{index});
105       }
106 
107       // Insert a loop into the body to compute the reduction. The loop ranges
108       // from [0.dim).
109       auto zero = rewriter.create<mlir::ConstantOp>(
110           loc, rewriter.getIndexType(),
111           rewriter.getIntegerAttr(rewriter.getIndexType(), 0));
112       // TODO(b/137624192) Use dimOp to make it shape independent.
113       auto upper = rewriter.create<mlir::ConstantOp>(
114           loc, rewriter.getIndexType(),
115           rewriter.getIntegerAttr(rewriter.getIndexType(), reduce_dim_size));
116       auto step = rewriter.create<mlir::ConstantOp>(
117           loc, rewriter.getIndexType(),
118           rewriter.getIntegerAttr(rewriter.getIndexType(), 1));
119       auto loop = rewriter.create<mlir::scf::ForOp>(loc, zero, upper, step);
120 
121       rewriter.setInsertionPointToStart(loop.getBody());
122       // Compute memrefs for the value to reduce. This makes it easier to just
123       // inline the body.
124       auto output = *reduce_op.out().begin();
125       auto resType = MemRefType::get(
126           llvm::None, getElementTypeOrSelf(output.getType()),
127           makeStridedLinearLayoutMap(llvm::None,
128                                      MemRefType::getDynamicStrideOrOffset(),
129                                      rewriter.getContext()));
130       OpFoldResult offset = launch_op.getThreadIds().x;
131       auto oneAttr = rewriter.getI64IntegerAttr(1);
132       OpFoldResult size = oneAttr;
133       OpFoldResult stride = oneAttr;
134       auto accumulator = rewriter.create<memref::SubViewOp>(
135           loc, resType, output, offset, size, stride);
136       llvm::SmallVector<Value, 4> indexings;
137       Value input_buffer = reduce_op.inputs().front();
138       auto input_type_rank =
139           input_buffer.getType().cast<MemRefType>().getRank();
140 
141       Value input = *reduce_op.operand_begin();
142       SmallVector<OpFoldResult> offsets = llvm::to_vector<4>(llvm::map_range(
143           llvm::seq<int>(0, input_type_rank), [&](int dim) -> OpFoldResult {
144             return dim == reducing_dimension ? loop.getInductionVar()
145                                              : launch_op.getThreadIds().x;
146           }));
147       SmallVector<OpFoldResult> sizes(input_type_rank, oneAttr);
148       SmallVector<OpFoldResult> strides(input_type_rank, oneAttr);
149       auto rhs = rewriter.create<memref::SubViewOp>(
150           loc, accumulator.getType(), input, offsets, sizes, strides);
151 
152       // Now copy over the actual body of the reduction, leaving out the
153       // terminator.
154       BlockAndValueMapping mapping;
155       mapping.map(reduce_op.body().getArgument(0), accumulator);
156       mapping.map(reduce_op.body().getArgument(1), rhs);
157       mapping.map(reduce_op.body().getArgument(2), accumulator);
158       for (auto& nested : reduce_op.body().front().without_terminator()) {
159         auto clone = rewriter.clone(nested, mapping);
160         for (auto pair : llvm::zip(nested.getResults(), clone->getResults())) {
161           mapping.map(std::get<0>(pair), std::get<1>(pair));
162         }
163       }
164 
165       // Finally, insert the terminator for the launchOp.
166       rewriter.setInsertionPointToEnd(&launch_op.body().front());
167       rewriter.create<mlir::gpu::TerminatorOp>(loc);
168     }
169 
170     rewriter.eraseOp(reduce_op);
171     return success();
172   };
173 };
174 
175 struct LhloLegalizeToGpuPass
176     : public LhloLegalizeToGpuPassBase<LhloLegalizeToGpuPass> {
getDependentDialectsmlir::lmhlo::__anon972849c40111::LhloLegalizeToGpuPass177   void getDependentDialects(DialectRegistry& registry) const override {
178     registry.insert<AffineDialect, gpu::GPUDialect, linalg::LinalgDialect,
179                     memref::MemRefDialect, scf::SCFDialect>();
180   }
181 
runOnFunctionmlir::lmhlo::__anon972849c40111::LhloLegalizeToGpuPass182   void runOnFunction() override {
183     OwningRewritePatternList patterns(&getContext());
184     ConversionTarget target(getContext());
185     target.addLegalDialect<linalg::LinalgDialect, memref::MemRefDialect,
186                            StandardOpsDialect, gpu::GPUDialect, scf::SCFDialect,
187                            LmhloDialect>();
188     target.addIllegalOp<ReduceOp>();
189     auto func = getFunction();
190     patterns.insert<LhloReduceToGPULaunchConverter>(func.getContext());
191     if (failed(applyPartialConversion(func, target, std::move(patterns)))) {
192       signalPassFailure();
193     }
194   }
195 };
196 
197 }  // namespace
198 
createLegalizeToGpuPass()199 std::unique_ptr<FunctionPass> createLegalizeToGpuPass() {
200   return std::make_unique<LhloLegalizeToGpuPass>();
201 }
202 
203 }  // namespace lmhlo
204 }  // namespace mlir
205