//===- AllReduceLowering.cpp - Implementation of all-reduce lowering ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements in-dialect lowering of the all-reduce op to a block of
// simpler instructions.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/GPU/GPUDialect.h"
#include "mlir/Dialect/GPU/Passes.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"

using namespace mlir;

namespace {

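/// Rewrites a gpu.all_reduce op within a gpu.func op into a sequence of
/// simpler ops (subgroup shuffles, workgroup memory accesses, and barriers).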
struct GpuAllReduceRewriter {
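  /// Callback that emits the ops accumulating two partial reduction values
  /// and returns the combined result.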
  using AccumulatorFactory = std::function<Value(Value, Value)>;

  GpuAllReduceRewriter(gpu::GPUFuncOp funcOp_, gpu::AllReduceOp reduceOp_,
                       PatternRewriter &rewriter_)
      : funcOp(funcOp_), reduceOp(reduceOp_), rewriter(rewriter_),
        loc(reduceOp.getLoc()), valueType(reduceOp.value().getType()),
        indexType(IndexType::get(reduceOp.getContext())),
        int32Type(IntegerType::get(/*width=*/32, reduceOp.getContext())) {}

  /// Creates an all_reduce across the workgroup.
  ///
  /// First reduce the elements within a subgroup. The first invocation of each
  /// subgroup writes the intermediate result to workgroup memory. After
  /// synchronizing the workgroup, the first subgroup reduces the values from
  /// workgroup memory. The result is broadcast to all invocations through
  /// workgroup memory.
  ///
  ///     %subgroup_reduce = `createSubgroupReduce(%operand)`
  ///     cond_br %is_first_lane, ^then1, ^continue1
  ///   ^then1:
  ///     store %subgroup_reduce, %workgroup_buffer[%subgroup_id]
  ///     br ^continue1
  ///   ^continue1:
  ///     gpu.barrier
  ///     %is_valid_subgroup = cmpi "slt" %invocation_idx, %num_subgroups
  ///     cond_br %is_valid_subgroup, ^then2, ^continue2
  ///   ^then2:
  ///     %partial_reduce = load %workgroup_buffer[%invocation_idx]
  ///     %all_reduce = `createSubgroupReduce(%partial_reduce)`
  ///     store %all_reduce, %workgroup_buffer[%zero]
  ///     br ^continue2
  ///   ^continue2:
  ///     gpu.barrier
  ///     %result = load %workgroup_buffer[%zero]
  ///     return %result
  ///
  void rewrite() {
    rewriter.setInsertionPoint(reduceOp);

    // Compute linear invocation index and workgroup size.
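    //
    //   invocationIdx = (tidZ * dimY + tidY) * dimX + tidX
    //   workgroupSize = dimX * dimY * dimZ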
    Value dimX = getDimOp<gpu::BlockDimOp>("x");
    Value dimY = getDimOp<gpu::BlockDimOp>("y");
    Value dimZ = getDimOp<gpu::BlockDimOp>("z");
    Value tidX = getDimOp<gpu::ThreadIdOp>("x");
    Value tidY = getDimOp<gpu::ThreadIdOp>("y");
    Value tidZ = getDimOp<gpu::ThreadIdOp>("z");
    Value tmp1 = create<MulIOp>(int32Type, tidZ, dimY);
    Value tmp2 = create<AddIOp>(int32Type, tmp1, tidY);
    Value tmp3 = create<MulIOp>(int32Type, tmp2, dimX);
    Value tmp4 = create<MulIOp>(int32Type, dimX, dimY);
    Value invocationIdx = create<AddIOp>(int32Type, tmp3, tidX);
    Value workgroupSize = create<MulIOp>(int32Type, tmp4, dimZ);

    // Compute lane id (invocation id within the subgroup).
    Value subgroupMask = create<ConstantIntOp>(kSubgroupSize - 1, int32Type);
    Value laneId = create<AndOp>(invocationIdx, subgroupMask);
    Value isFirstLane = create<CmpIOp>(CmpIPredicate::eq, laneId,
                                       create<ConstantIntOp>(0, int32Type));

    Value numThreadsWithSmallerSubgroupId =
        create<SubIOp>(invocationIdx, laneId);
    // The number of active invocations starting from the current subgroup.
    // The consumers do not require the value to be clamped to the size of the
    // subgroup.
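    // E.g. for workgroupSize = 70 and kSubgroupSize = 32, invocations in the
    // last (partial) subgroup cover ids 64..69 and see activeWidth = 6.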
    Value activeWidth =
        create<SubIOp>(workgroupSize, numThreadsWithSmallerSubgroupId);

    // Create the factory for the op that accumulates two values.
    AccumulatorFactory accumFactory = getFactory();
    assert(accumFactory && "failed to create accumulator factory");

    // Reduce elements within each subgroup to produce the intermediate results.
    Value subgroupReduce = createSubgroupReduce(activeWidth, laneId,
                                                reduceOp.value(), accumFactory);

    // Add workgroup buffer to parent function for intermediate result.
    Value buffer = createWorkgroupBuffer();

    // Write the intermediate results to workgroup memory, using the first lane
    // of each subgroup.
    createPredicatedBlock(isFirstLane, [&] {
      Value subgroupId = getDivideBySubgroupSize(invocationIdx);
      Value index = create<IndexCastOp>(indexType, subgroupId);
      create<StoreOp>(subgroupReduce, buffer, index);
    });
    create<gpu::BarrierOp>();

    // Compute number of active subgroups.
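    // Adding subgroupMask before dividing rounds up, i.e. this computes
    // ceil(workgroupSize / kSubgroupSize).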
    Value biasedBlockSize =
        create<AddIOp>(int32Type, workgroupSize, subgroupMask);
    Value numSubgroups = getDivideBySubgroupSize(biasedBlockSize);
    Value isValidSubgroup =
        create<CmpIOp>(CmpIPredicate::slt, invocationIdx, numSubgroups);

    // Use the first numSubgroups invocations to reduce the intermediate results
    // from workgroup memory. The final result is written to workgroup memory
    // again.
    Value zero = create<ConstantIndexOp>(0);
    createPredicatedBlock(isValidSubgroup, [&] {
      Value index = create<IndexCastOp>(indexType, invocationIdx);
      Value value = create<LoadOp>(valueType, buffer, index);
      Value result =
          createSubgroupReduce(numSubgroups, laneId, value, accumFactory);
      create<StoreOp>(result, buffer, zero);
    });

    // Synchronize workgroup and load result from workgroup memory.
    create<gpu::BarrierOp>();
    Value result = create<LoadOp>(valueType, buffer, zero);

    rewriter.replaceOp(reduceOp, result);
  }

private:
  // Shortcut to create an op from rewriter using loc as the first argument.
  template <typename T, typename... Args> T create(Args... args) {
    return rewriter.create<T>(loc, std::forward<Args>(args)...);
  }

  // Creates dimension op of type T, with the result casted to int32.
  template <typename T> Value getDimOp(StringRef dimension) {
    Value dim = create<T>(indexType, rewriter.getStringAttr(dimension));
    return create<IndexCastOp>(int32Type, dim);
  }

  /// Adds a workgroup buffer of kSubgroupSize elements to funcOp's workgroup
  /// attributions and returns it.
  Value createWorkgroupBuffer() {
    int workgroupMemoryAddressSpace =
        gpu::GPUDialect::getWorkgroupAddressSpace();
    auto bufferType =
        MemRefType::get({kSubgroupSize}, valueType, ArrayRef<AffineMap>{},
                        workgroupMemoryAddressSpace);
    return funcOp.addWorkgroupAttribution(bufferType);
  }

  /// Returns an accumulator factory using either the op attribute or the body
  /// region.
  AccumulatorFactory getFactory() {
    auto &body = reduceOp.body();
    if (!body.empty())
      return getFactory(body);
    auto opAttr = reduceOp.op();
    if (opAttr)
      return getFactory(*opAttr);
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that clones the body. The body's entry
  /// block is expected to have 2 arguments. The gpu.yield ops return the
  /// accumulated value of the same type.
  AccumulatorFactory getFactory(Region &body) {
    return AccumulatorFactory([&](Value lhs, Value rhs) {
      Block *block = rewriter.getInsertionBlock();
      Block *split = rewriter.splitBlock(block, rewriter.getInsertionPoint());

      // Insert accumulator body between the split blocks, mapping its two
      // arguments to lhs and rhs.
      BlockAndValueMapping mapping;
      mapping.map(body.getArgument(0), lhs);
      mapping.map(body.getArgument(1), rhs);
      rewriter.cloneRegionBefore(body, *split->getParent(),
                                 split->getIterator(), mapping);

      // Add a branch from the block before the inserted body into the body.
      block = block->getNextNode();
      create<BranchOp>(block, ValueRange());

      // Replace all gpu.yield ops with branches out of the body.
      for (; block != split; block = block->getNextNode()) {
        Operation *terminator = block->getTerminator();
        if (!isa<gpu::YieldOp>(terminator))
          continue;
        rewriter.setInsertionPointToEnd(block);
        rewriter.replaceOpWithNewOp<BranchOp>(
            terminator, split, ValueRange(terminator->getOperand(0)));
      }

      // Return accumulator result.
      rewriter.setInsertionPointToStart(split);
      return split->addArgument(lhs.getType());
    });
  }

  /// Returns an accumulator factory that creates an op specified by opName.
  AccumulatorFactory getFactory(StringRef opName) {
    bool isFloatingPoint = valueType.isa<FloatType>();
    if (opName == "add")
      return isFloatingPoint ? getFactory<AddFOp>() : getFactory<AddIOp>();
    if (opName == "mul")
      return isFloatingPoint ? getFactory<MulFOp>() : getFactory<MulIOp>();
    if (opName == "and")
      return getFactory<AndOp>();
    if (opName == "or")
      return getFactory<OrOp>();
    if (opName == "xor")
      return getFactory<XOrOp>();
    if (opName == "max")
      return isFloatingPoint
                 ? getCmpFactory<CmpFOp, CmpFPredicate, CmpFPredicate::UGT>()
                 : getCmpFactory<CmpIOp, CmpIPredicate, CmpIPredicate::ugt>();
    if (opName == "min")
      return isFloatingPoint
                 ? getCmpFactory<CmpFOp, CmpFPredicate, CmpFPredicate::ULT>()
                 : getCmpFactory<CmpIOp, CmpIPredicate, CmpIPredicate::ult>();
    return AccumulatorFactory();
  }

  /// Returns an accumulator factory that creates an op of type T.
  template <typename T> AccumulatorFactory getFactory() {
    return [&](Value lhs, Value rhs) {
      return create<T>(lhs.getType(), lhs, rhs);
    };
  }

  /// Returns an accumulator factory for comparison ops such as min and max.
  /// T is the type of the compare op.
  template <typename T, typename PredicateEnum, PredicateEnum predicate>
  AccumulatorFactory getCmpFactory() const {
    return [&](Value lhs, Value rhs) {
      Value cmp = rewriter.create<T>(loc, predicate, lhs, rhs);
      return rewriter.create<SelectOp>(loc, cmp, lhs, rhs);
    };
  }

  /// Creates an if-block skeleton and calls the two factories to generate the
  /// ops in the `then` and `else` blocks.
  ///
  ///     cond_br %condition, ^then, ^else
  ///   ^then:
  ///     %then_operands = `thenOpsFactory()`
  ///     br ^continue(%then_operands)
  ///   ^else:
  ///     %else_operands = `elseOpsFactory()`
  ///     br ^continue(%else_operands)
  ///   ^continue(%block_operands):
  ///
  template <typename ThenOpsFactory, typename ElseOpsFactory>
  void createIf(Value condition, ThenOpsFactory &&thenOpsFactory,
                ElseOpsFactory &&elseOpsFactory) {
    Block *currentBlock = rewriter.getInsertionBlock();
    auto currentPoint = rewriter.getInsertionPoint();

    Block *thenBlock = rewriter.splitBlock(currentBlock, currentPoint);
    Block *elseBlock = rewriter.splitBlock(thenBlock, thenBlock->begin());
    Block *continueBlock = rewriter.splitBlock(elseBlock, elseBlock->begin());

    rewriter.setInsertionPointToEnd(currentBlock);
    create<CondBranchOp>(condition, thenBlock,
                         /*trueOperands=*/ArrayRef<Value>(), elseBlock,
                         /*falseOperands=*/ArrayRef<Value>());

    rewriter.setInsertionPointToStart(thenBlock);
    auto thenOperands = thenOpsFactory();
    create<BranchOp>(continueBlock, thenOperands);

    rewriter.setInsertionPointToStart(elseBlock);
    auto elseOperands = elseOpsFactory();
    create<BranchOp>(continueBlock, elseOperands);

    assert(thenOperands.size() == elseOperands.size() &&
           "mismatching number of then/else block operands");
    rewriter.setInsertionPointToStart(continueBlock);
    for (auto operand : thenOperands)
      continueBlock->addArgument(operand.getType());
  }

  /// Shortcut for createIf with empty else block and no block operands.
  template <typename Factory>
  void createPredicatedBlock(Value condition, Factory &&predicatedOpsFactory) {
    static_assert(std::is_same<decltype(predicatedOpsFactory()), void>::value,
                  "predicatedOpsFactory should not return any value");
    createIf(
        condition,
        [&] {
          predicatedOpsFactory();
          return ArrayRef<Value>();
        },
        [&] { return ArrayRef<Value>(); });
  }

  /// Creates a reduction across the first activeWidth lanes of a subgroup, or
  /// the entire subgroup if activeWidth is larger than the subgroup width.
  /// The first lane returns the result; the return values of all other lanes
  /// are undefined.
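  ///
  /// Sketch of the emitted IR for the full-subgroup case (the partial case
  /// additionally predicates each accumulation on the shuffle's validity
  /// result):
  ///
  ///     %shfl, %valid = gpu.shuffle %value, %offset, %width xor
  ///     %value = `accumFactory(%value, %shfl)`
  ///
  /// repeated with %offset = 1, 2, 4, 8, 16, so that after the last step the
  /// first lane holds the reduction over all kSubgroupSize values.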
  Value createSubgroupReduce(Value activeWidth, Value laneId, Value operand,
                             AccumulatorFactory &accumFactory) {
    Value subgroupSize = create<ConstantIntOp>(kSubgroupSize, int32Type);
    Value isPartialSubgroup =
        create<CmpIOp>(CmpIPredicate::slt, activeWidth, subgroupSize);
    std::array<Type, 2> shuffleType = {valueType, rewriter.getI1Type()};
    auto xorAttr = rewriter.getStringAttr("xor");

    createIf(
        isPartialSubgroup,
        // Generate reduction over a (potentially) partial subgroup.
        [&] {
          Value value = operand;
          // Repeatedly shuffle value from 'laneId ^ i' and accumulate if source
          // lane is within the active range. The accumulated value is available
          // in the first lane.
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
                                                    activeWidth, xorAttr);
            // Skip the accumulation if the shuffle op read from a lane outside
            // of the active range.
            createIf(
                shuffleOp.getResult(1),
                [&] {
                  return SmallVector<Value, 1>{
                      accumFactory(value, shuffleOp.getResult(0))};
                },
                [&] { return llvm::makeArrayRef(value); });
            value = rewriter.getInsertionBlock()->getArgument(0);
          }
          return SmallVector<Value, 1>{value};
        },
        // Generate a reduction over the entire subgroup. This is a
        // specialization of the above reduction with unconditional
        // accumulation.
        [&] {
          Value value = operand;
          for (int i = 1; i < kSubgroupSize; i <<= 1) {
            Value offset = create<ConstantIntOp>(i, int32Type);
            auto shuffleOp = create<gpu::ShuffleOp>(shuffleType, value, offset,
                                                    subgroupSize, xorAttr);
            value = accumFactory(value, shuffleOp.getResult(0));
          }
          return SmallVector<Value, 1>{value};
        });
    return rewriter.getInsertionBlock()->getArgument(0);
  }

  /// Returns value divided by the subgroup size (i.e. 32).
  Value getDivideBySubgroupSize(Value value) {
    Value subgroupSize = create<ConstantIntOp>(kSubgroupSize, int32Type);
    return create<SignedDivIOp>(int32Type, value, subgroupSize);
  }

  gpu::GPUFuncOp funcOp;
  gpu::AllReduceOp reduceOp;
  PatternRewriter &rewriter;

  Location loc;
  Type valueType;
  Type indexType;
  Type int32Type;

  static constexpr int kSubgroupSize = 32;
};

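/// Pattern that lowers all gpu.all_reduce ops within a gpu.func op using
/// GpuAllReduceRewriter.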
struct GpuAllReduceConversion : public RewritePattern {
  explicit GpuAllReduceConversion(MLIRContext *context)
      : RewritePattern(gpu::GPUFuncOp::getOperationName(), 1, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto funcOp = cast<gpu::GPUFuncOp>(op);
    auto callback = [&](gpu::AllReduceOp reduceOp) {
      GpuAllReduceRewriter(funcOp, reduceOp, rewriter).rewrite();
      // Performing a rewrite invalidates the walk iterator. Report interrupt
      // so that we can start a new walk until all all_reduce ops are replaced.
      return WalkResult::interrupt();
    };
    while (funcOp.walk(callback).wasInterrupted()) {
    }
    return success();
  }
};
} // namespace

void mlir::populateGpuAllReducePatterns(MLIRContext *context,
                                        OwningRewritePatternList &patterns) {
  patterns.insert<GpuAllReduceConversion>(context);
}