//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {

struct GpuAllReduceRewriter {
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp_, gpu::AllReduceOp reduceOp_,
                       PatternRewriter &rewriter_)
      : funcOp(funcOp_), reduceOp(reduceOp_), rewriter(rewriter_),
        loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(/*width=*/32, reduceOp.getContext())) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcasted to all invocations through
  /// workgroup memory.
  ///
  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///     cond_br %is_first_lane, ^then1, ^continue1
  ///   ^then1:
  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///     br ^continue1
  ///   ^continue1:
  ///     gpu.barrier
  ///     %is_valid_subgroup = cmpi "slt" %invocation_idx, %num_subgroups
  ///     cond_br %is_valid_subgroup, ^then2, ^continue2
  ///   ^then2:
  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///     store %all_reduce, %workgroup_buffer[%zero]
  ///     br ^continue2
  ///   ^continue2:
  ///     gpu.barrier
  ///     %result = load %workgroup_buffer[%zero]
  ///     return %result
  ///
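  /// As an illustration only (the exact textual form depends on the MLIR
  /// version in use), the kind of op this rewrites is:
  ///
  ///     %sum = "gpu.all_reduce"(%operand) ({}) {op = "add"} : (f32) -> f32
  ///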
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size.
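    // The linearization follows the usual x-fastest convention:
    //   invocationIdx = (tidZ * dimY + tidY) * dimX + tidX
    //   workgroupSize = dimX * dimY * dimZ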
    Value dimX = getDimOp<gpu::BlockDimOp>("x");
    Value dimY = getDimOp<gpu::BlockDimOp>("y");
    Value dimZ = getDimOp<gpu::BlockDimOp>("z");
    Value tidX = getDimOp<gpu::ThreadIdOp>("x");
    Value tidY = getDimOp<gpu::ThreadIdOp>("y");
    Value tidZ = getDimOp<gpu::ThreadIdOp>("z");
    Value tmp1 = create<MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<MulIOp>(int32Type, tmp4, dimZ);

    // Compute lane id (invocation id within the subgroup).
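    // kSubgroupSize is a power of two, so masking with kSubgroupSize - 1 is
    // equivalent to invocationIdx % kSubgroupSize.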
    Value subgroupMask = create<ConstantIntOp>(kSubgroupSize - 1, int32Type);
    Value laneId = create<AndOp>(invocationIdx, subgroupMask);
    Value isFirstLane = create<CmpIOp>(CmpIPredicate::eq, laneId,
                                       create<ConstantIntOp>(0, int32Type));

    Value numThreadsWithSmallerSubgroupId =
        create<SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
    Value activeWidth =
        create<SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);

    // Create the factory for the op which accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate
    // results.
    Value subgroupReduce = createSubgroupReduce(activeWidth, laneId,
                                                reduceOp.value(), accumFactory);

    // Add workgroup buffer to parent function for intermediate result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first
    // lane of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<IndexCastOp>(indexType, subgroupId);
      create<StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute number of active subgroups.
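    // Adding kSubgroupSize - 1 before dividing rounds the quotient up, i.e.
    // numSubgroups = ceil(workgroupSize / kSubgroupSize).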
    Value biasedBlockSize =
        create<AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup =
        create<CmpIOp>(CmpIPredicate::slt, invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate
    // results from workgroup memory. The final result is written to workgroup
    // memory again.
    Value zero = create<ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<IndexCastOp>(indexType, invocationIdx);
      Value value = create<LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<StoreOp>(result, buffer, zero);
    });

    // Synchronize workgroup and load result from workgroup memory.
    create<gpu::BarrierOp>();
    Value result = create<LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
  template <typename T, typename... Args> T create(Args... args) {
    return rewriter.create<T>(loc, std::forward<Args>(args)...);
  }

  // Creates a dimension op of type T, with the result cast to int32.
  template <typename T> Value getDimOp(StringRef dimension) {
    Value dim = create<T>(indexType, rewriter.getStringAttr(dimension));
    return create<IndexCastOp>(int32Type, dim);
  }

  /// Adds a workgroup memory buffer for the intermediate subgroup results to
  /// funcOp's workgroup attributions and returns it.
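  /// The buffer has type memref<kSubgroupSize x valueType> and is placed in
  /// the GPU dialect's workgroup address space (the number reported by
  /// getWorkgroupAddressSpace(), commonly 3).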
  Value createWorkgroupBuffer() {
    int workgroupMemoryAddressSpace =
        gpu::GPUDialect::getWorkgroupAddressSpace();
    auto bufferType =
        MemRefType::get({kSubgroupSize}, valueType, ArrayRef<AffineMap>{},
                        workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType);
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.body();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.op();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield op at the end
  /// returns the accumulated value of the same type.
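  ///
  /// For example, the body could look like the following sketch (assuming an
  /// f32 reduction written with standard-dialect ops):
  ///
  ///   ^bb0(%lhs: f32, %rhs: f32):
  ///     %sum = addf %lhs, %rhs : f32
  ///     "gpu.yield"(%sum) : (f32) -> ()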
  AccumulatorFactory getFactory(Region &body) {
    return AccumulatorFactory([&](Value lhs, Value rhs) {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Insert the accumulator body between the current block and the split
      // block.
      BlockAndValueMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add a branch before the inserted body that jumps into it.
      block = block->getNextNode();
      create<BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with branch out of body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType());
    });
  }

  /// Returns an accumulator factory that creates an op specified by opName.
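  /// For example, "add" maps to AddFOp for floating-point values and to
  /// AddIOp for integer values; "max" and "min" lower to a compare plus
  /// select.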
  AccumulatorFactory getFactory(StringRef opName) {
    bool isFloatingPoint = valueType.isa<FloatType>();
    if (opName == "add")
      return isFloatingPoint ? getFactory<AddFOp>() : getFactory<AddIOp>();
    if (opName == "mul")
      return isFloatingPoint ? getFactory<MulFOp>() : getFactory<MulIOp>();
    if (opName == "and")
      return getFactory<AndOp>();
    if (opName == "or")
      return getFactory<OrOp>();
    if (opName == "xor")
      return getFactory<XOrOp>();
    if (opName == "max")
      return isFloatingPoint
                 ? getCmpFactory<CmpFOp, CmpFPredicate, CmpFPredicate::UGT>()
                 : getCmpFactory<CmpIOp, CmpIPredicate, CmpIPredicate::ugt>();
    if (opName == "min")
      return isFloatingPoint
                 ? getCmpFactory<CmpFOp, CmpFPredicate, CmpFPredicate::ULT>()
                 : getCmpFactory<CmpIOp, CmpIPredicate, CmpIPredicate::ult>();
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that creates an op of type T.
  template <typename T> AccumulatorFactory getFactory() {
    return [&](Value lhs, Value rhs) {
      return create<T>(lhs.getType(), lhs, rhs);
    };
  }

  /// Returns an accumulator factory for comparison ops such as min and max.
  /// T is the type of the compare op.
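  ///
  /// For instance, a floating-point "max" accumulator produces ops along the
  /// lines of (a sketch):
  ///   %cmp = cmpf "ugt", %lhs, %rhs : f32
  ///   %res = select %cmp, %lhs, %rhs : f32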
  template <typename T, typename PredicateEnum, PredicateEnum predicate>
  AccumulatorFactory getCmpFactory() const {
    return [&](Value lhs, Value rhs) {
      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
      return rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` blocks.
  ///
  ///   cond_br %condition, ^then, ^else
  ///   ^then:
  ///     %then_operands = `thenOpsFactory()`
  ///     br ^continue(%then_operands)
  ///   ^else:
  ///     %else_operands = `elseOpsFactory()`
  ///     br ^continue(%else_operands)
  ///   ^continue(%block_operands):
  ///
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<CondBranchOp>(condition, thenBlock,
                         /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                         /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size());
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType());
  }

  /// Shortcut for createIf with empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is larger than the subgroup width.
  /// The first lane returns the result; the return values of all other lanes
  /// are undefined.
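  ///
  /// The reduction uses a butterfly pattern: for offsets 1, 2, 4, ... up to
  /// half the subgroup size, each lane combines its value with the value
  /// shuffled in from lane `laneId ^ offset`, so after the final step lane 0
  /// holds the reduction over all participating lanes.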
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<ConstantIntOp>(kSubgroupSize, int32Type);
    Value isPartialSubgroup =
        create<CmpIOp>(CmpIPredicate::slt, activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};
    auto xorAttr = rewriter.getStringAttr("xor");

    createIf(
        isPartialSubgroup,
        // Generate reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if
          // source lane is within the active range. The accumulated value is
          // available in the first lane.
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
                                                    activeWidth, xorAttr);
            // Skip the accumulation if the shuffle op read from a lane
            // outside of the active range.
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::makeArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
                                                    subgroupSize, xorAttr);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns value divided by the subgroup size (i.e. 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<ConstantIntOp>(kSubgroupSize, int32Type);
    return create<SignedDivIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  Type int32Type;

  static constexpr int kSubgroupSize = 32;
};

struct GpuAllReduceConversion : public RewritePattern {
  explicit GpuAllReduceConversion(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);
    auto callback = [&](gpu::AllReduceOp reduceOp) {
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
      // Performing a rewrite invalidates the walk iterator. Report interrupt
      // so that we can start a new walk until all all_reduce ops are replaced.
      return WalkResult::interrupt();
    };
    while (funcOp.walk(callback).wasInterrupted()) {
    }
    return success();
  }
};
} // namespace

void mlir::populateGpuAllReducePatterns(MLIRContext *context,
                                        OwningRewritePatternList &patterns) {
  patterns.insert<GpuAllReduceConversion>(context);
}