//===- Utils.cpp - Utilities to support the Linalg dialect ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for the Linalg dialect.
//
//===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/Linalg/Utils/Utils.h"
14
15 #include "mlir/Dialect/Affine/EDSC/Intrinsics.h"
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
18 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
19 #include "mlir/Dialect/SCF/EDSC/Builders.h"
20 #include "mlir/Dialect/SCF/SCF.h"
21 #include "mlir/Dialect/StandardOps/IR/Ops.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineMap.h"
24 #include "mlir/IR/Matchers.h"
25 #include "mlir/IR/OpImplementation.h"
26 #include "mlir/Pass/Pass.h"
27 #include "mlir/Transforms/LoopUtils.h"
28
29 using namespace mlir;
30 using namespace mlir::linalg;
31 using namespace mlir::scf;
32
33 Optional<RegionMatcher::BinaryOpKind>
matchAsScalarBinaryOp(GenericOp op)34 RegionMatcher::matchAsScalarBinaryOp(GenericOp op) {
35 auto ®ion = op.region();
36 if (!llvm::hasSingleElement(region))
37 return llvm::None;
38
39 Block &block = region.front();
40 if (block.getNumArguments() != 2 ||
41 !block.getArgument(0).getType().isSignlessIntOrFloat() ||
42 !block.getArgument(1).getType().isSignlessIntOrFloat())
43 return llvm::None;
44
45 auto &ops = block.getOperations();
46 if (!llvm::hasSingleElement(block.without_terminator()))
47 return llvm::None;
48
49 using mlir::matchers::m_Val;
50 auto a = m_Val(block.getArgument(0));
51 auto b = m_Val(block.getArgument(1));
52
53 auto addPattern = m_Op<linalg::YieldOp>(m_Op<AddIOp>(a, b));
54 if (addPattern.match(&ops.back()))
55 return BinaryOpKind::IAdd;
56
57 return llvm::None;
58 }
59
isParallelIteratorType(Attribute attr)60 bool mlir::linalg::isParallelIteratorType(Attribute attr) {
61 if (auto strAttr = attr.dyn_cast<StringAttr>()) {
62 return strAttr.getValue() == getParallelIteratorTypeName();
63 }
64 return false;
65 }
66
isReductionIteratorType(Attribute attr)67 bool mlir::linalg::isReductionIteratorType(Attribute attr) {
68 if (auto strAttr = attr.dyn_cast<StringAttr>()) {
69 return strAttr.getValue() == getReductionIteratorTypeName();
70 }
71 return false;
72 }
73
isWindowIteratorType(Attribute attr)74 bool mlir::linalg::isWindowIteratorType(Attribute attr) {
75 if (auto strAttr = attr.dyn_cast<StringAttr>()) {
76 return strAttr.getValue() == getWindowIteratorTypeName();
77 }
78 return false;
79 }
80
81 /// Explicit instantiation of loop nest generator for different loop types.
82 template struct mlir::linalg::GenerateLoopNest<scf::ForOp>;
83 template struct mlir::linalg::GenerateLoopNest<scf::ParallelOp>;
84 template struct mlir::linalg::GenerateLoopNest<AffineForOp>;
85
86 /// Given a list of subview ranges, extract individual values for lower, upper
87 /// bounds and steps and put them into the corresponding vectors.
unpackRanges(ArrayRef<Range> ranges,SmallVectorImpl<Value> & lbs,SmallVectorImpl<Value> & ubs,SmallVectorImpl<Value> & steps)88 static void unpackRanges(ArrayRef<Range> ranges, SmallVectorImpl<Value> &lbs,
89 SmallVectorImpl<Value> &ubs,
90 SmallVectorImpl<Value> &steps) {
91 for (Range range : ranges) {
92 lbs.emplace_back(range.offset);
93 ubs.emplace_back(range.size);
94 steps.emplace_back(range.stride);
95 }
96 }
97
98 namespace mlir {
99 namespace linalg {
100
getStaticShape(LinalgOp linalgOp)101 SmallVector<int64_t, 8> getStaticShape(LinalgOp linalgOp) {
102 SmallVector<int64_t, 8> res;
103 for (Value v : linalgOp.getShapedOperands()) {
104 auto shape = v.getType().cast<ShapedType>().getShape();
105 res.append(shape.begin(), shape.end());
106 }
107 if (linalgOp.getNumInitTensors())
108 return res;
109 for (Value v : linalgOp.getOperation()->getResults()) {
110 auto shape = v.getType().cast<ShapedType>().getShape();
111 res.append(shape.begin(), shape.end());
112 }
113 return res;
114 }
115
getStaticLoopRanges(LinalgOp linalgOp)116 Optional<SmallVector<int64_t, 4>> getStaticLoopRanges(LinalgOp linalgOp) {
117 SmallVector<int64_t, 8> viewSizes = getStaticShape(linalgOp);
118 AffineMap invertedMap = linalgOp.getShapesToLoopsMap();
119 if (!invertedMap)
120 return {};
121 return invertedMap.compose(viewSizes);
122 }
123
124 /// Specialization to build an scf "for" nest.
125 template <>
doit(ArrayRef<Range> loopRanges,ValueRange iterArgInitValues,ArrayRef<Attribute> iteratorTypes,function_ref<scf::ValueVector (ValueRange,ValueRange)> bodyBuilderFn,Optional<LinalgLoopDistributionOptions> distributionOptions)126 void GenerateLoopNest<scf::ForOp>::doit(
127 ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,
128 ArrayRef<Attribute> iteratorTypes,
129 function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
130 Optional<LinalgLoopDistributionOptions> distributionOptions) {
131 // Create procInfo so it dominates loops, if appropriate.
132 OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
133 Location loc = edsc::ScopedContext::getLocation();
134 SmallVector<ProcInfo, 2> procInfo;
135 if (distributionOptions.hasValue())
136 procInfo = distributionOptions->procInfo(builder, loc, loopRanges);
137
138 SmallVector<Value, 4> lbs, ubs, steps;
139 unpackRanges(loopRanges, lbs, ubs, steps);
140 LoopNest loopNest =
141 edsc::loopNestBuilder(lbs, ubs, steps, iterArgInitValues, bodyBuilderFn);
142
143 if (!distributionOptions.hasValue() || loopNest.loops.empty())
144 return;
145
146 // Only supports cyclic distribution for now.
147 for (auto it : llvm::zip(loopNest.loops, procInfo,
148 distributionOptions->distributionMethod))
149 if (std::get<2>(it) == DistributionMethod::Cyclic)
150 mapLoopToProcessorIds(std::get<0>(it), std::get<1>(it).procId,
151 std::get<1>(it).nprocs);
152 }
153
154 /// Specialization to build affine "for" nest.
155 template <>
doit(ArrayRef<Range> loopRanges,ValueRange iterArgInitValues,ArrayRef<Attribute> iteratorTypes,function_ref<scf::ValueVector (ValueRange,ValueRange)> bodyBuilderFn,Optional<LinalgLoopDistributionOptions>)156 void GenerateLoopNest<AffineForOp>::doit(
157 ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,
158 ArrayRef<Attribute> iteratorTypes,
159 function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
160 Optional<LinalgLoopDistributionOptions>) {
161 assert(iterArgInitValues.empty() && "unexpected AffineForOp init values");
162 SmallVector<Value, 4> lbs, ubs, steps;
163 unpackRanges(loopRanges, lbs, ubs, steps);
164
165 // Affine loops require constant steps.
166 SmallVector<int64_t, 4> constantSteps;
167 constantSteps.reserve(steps.size());
168 for (Value v : steps) {
169 auto op = v.getDefiningOp<ConstantIndexOp>();
170 assert(op && "Affine loops require constant steps");
171 constantSteps.push_back(op.getValue());
172 }
173
174 auto bodyBuilderWithoutIterArgsFn = [&](ValueRange ivs) {
175 bodyBuilderFn(ivs, {});
176 };
177 edsc::affineLoopNestBuilder(lbs, ubs, constantSteps,
178 bodyBuilderWithoutIterArgsFn);
179 }
180
181 /// Update the `lb`, `ub` and `step` to get per processor `lb`, `ub` and `step`.
updateBoundsForCyclicDistribution(OpBuilder & builder,Location loc,Value procId,Value nprocs,Value & lb,Value & ub,Value & step)182 static void updateBoundsForCyclicDistribution(OpBuilder &builder, Location loc,
183 Value procId, Value nprocs,
184 Value &lb, Value &ub,
185 Value &step) {
186 using edsc::op::operator+;
187 using edsc::op::operator*;
188 lb = lb + (procId * step);
189 step = nprocs * step;
190 }
191
192 /// Generates a loop nest consisting of scf.parallel and scf.for, depending
193 /// on the `iteratorTypes.` Consecutive parallel loops create a single
194 /// scf.parallel operation; each sequential loop creates a new scf.for
195 /// operation. The body of the innermost loop is populated by
196 /// `bodyBuilderFn` that accepts a range of induction variables for all
197 /// loops. `ivStorage` is used to store the partial list of induction
198 /// variables.
199 // TODO: this function can be made iterative instead. However, it
200 // will have at most as many recursive calls as nested loops, which rarely
201 // exceeds 10.
202 static void
generateParallelLoopNest(ValueRange lbs,ValueRange ubs,ValueRange steps,ArrayRef<Attribute> iteratorTypes,function_ref<void (ValueRange)> bodyBuilderFn,SmallVectorImpl<Value> & ivStorage,ArrayRef<DistributionMethod> distributionMethod={})203 generateParallelLoopNest(ValueRange lbs, ValueRange ubs, ValueRange steps,
204 ArrayRef<Attribute> iteratorTypes,
205 function_ref<void(ValueRange)> bodyBuilderFn,
206 SmallVectorImpl<Value> &ivStorage,
207 ArrayRef<DistributionMethod> distributionMethod = {}) {
208 assert(lbs.size() == ubs.size());
209 assert(lbs.size() == steps.size());
210 assert(lbs.size() == iteratorTypes.size());
211
212 // If there are no (more) loops to be generated, generate the body and be
213 // done with it.
214 if (iteratorTypes.empty())
215 return bodyBuilderFn(ivStorage);
216
217 // Find the outermost parallel loops and drop their types from the list.
218 unsigned nLoops = iteratorTypes.size();
219 unsigned nOuterPar =
220 nLoops - iteratorTypes.drop_while(isParallelIteratorType).size();
221
222 // If there are no outer parallel loops, generate one sequential loop and
223 // recurse. Note that we wouldn't have dropped anything from `iteratorTypes`
224 // in this case.
225 if (nOuterPar == 0) {
__anon7ecdf6bc0202(Value iv) 226 edsc::loopNestBuilder(lbs[0], ubs[0], steps[0], [&](Value iv) {
227 ivStorage.push_back(iv);
228 generateParallelLoopNest(lbs.drop_front(), ubs.drop_front(),
229 steps.drop_front(), iteratorTypes.drop_front(),
230 bodyBuilderFn, ivStorage, distributionMethod);
231 });
232 return;
233 }
234 if (distributionMethod.empty()) {
235 // Generate a single parallel loop-nest operation for all outermost
236 // parallel loops and recurse.
237 edsc::OperationBuilder<scf::ParallelOp>(
238 lbs.take_front(nOuterPar), ubs.take_front(nOuterPar),
239 steps.take_front(nOuterPar),
__anon7ecdf6bc0302(OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) 240 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
241 edsc::ScopedContext context(nestedBuilder, nestedLoc);
242 ivStorage.append(localIvs.begin(), localIvs.end());
243 generateParallelLoopNest(
244 lbs.drop_front(nOuterPar), ubs.drop_front(nOuterPar),
245 steps.drop_front(nOuterPar), iteratorTypes.drop_front(nOuterPar),
246 bodyBuilderFn, ivStorage,
247 (distributionMethod.size() < nOuterPar)
248 ? ArrayRef<DistributionMethod>()
249 : distributionMethod.drop_front(nOuterPar));
250 });
251 return;
252 }
253
254 // Process all consecutive similarly distributed loops simultaneously.
255 DistributionMethod methodToUse = distributionMethod[0];
256 unsigned numProcessed = 1;
257 for (unsigned i = 1; i < nOuterPar && i < distributionMethod.size(); ++i) {
258 if (distributionMethod[i] != methodToUse)
259 break;
260 numProcessed++;
261 }
262
263 switch (methodToUse) {
264 case DistributionMethod::Cyclic: {
265 // Generate a single parallel loop-nest operation for all outermost
266 // parallel loops and recurse.
267 edsc::OperationBuilder<scf::ParallelOp>(
268 lbs.take_front(numProcessed), ubs.take_front(numProcessed),
269 steps.take_front(numProcessed),
__anon7ecdf6bc0402(OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) 270 [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange localIvs) {
271 edsc::ScopedContext context(nestedBuilder, nestedLoc);
272 ivStorage.append(localIvs.begin(), localIvs.end());
273 generateParallelLoopNest(
274 lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
275 steps.drop_front(numProcessed),
276 iteratorTypes.drop_front(numProcessed), bodyBuilderFn, ivStorage,
277 (distributionMethod.size() < numProcessed)
278 ? ArrayRef<DistributionMethod>()
279 : distributionMethod.drop_front(numProcessed));
280 });
281 return;
282 }
283 case DistributionMethod::CyclicNumProcsGeNumIters: {
284 // Check (for the processed loops) that the iteration is in-bounds.
285 using edsc::op::slt;
286 using edsc::op::operator&&;
287 Value cond = slt(lbs[0], ubs[0]);
288 for (unsigned i = 1; i < numProcessed; ++i)
289 cond = cond && slt(lbs[i], ubs[i]);
290 ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
__anon7ecdf6bc0502() 291 edsc::conditionBuilder(cond, [&]() {
292 generateParallelLoopNest(
293 lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
294 steps.drop_front(numProcessed),
295 iteratorTypes.drop_front(numProcessed), bodyBuilderFn, ivStorage,
296 distributionMethod.drop_front(numProcessed));
297 });
298 return;
299 }
300 case DistributionMethod::CyclicNumProcsEqNumIters:
301 // No check/loops needed here. Set the `%iv` to be the `%lb` and proceed
302 // with inner loop generation.
303 ivStorage.append(lbs.begin(), std::next(lbs.begin(), numProcessed));
304 generateParallelLoopNest(
305 lbs.drop_front(numProcessed), ubs.drop_front(numProcessed),
306 steps.drop_front(numProcessed), iteratorTypes.drop_front(numProcessed),
307 bodyBuilderFn, ivStorage, distributionMethod.drop_front(numProcessed));
308 return;
309 }
310 }
311
312 /// Specialization for generating a mix of parallel and sequential scf loops.
313 template <>
doit(ArrayRef<Range> loopRanges,ValueRange iterArgInitValues,ArrayRef<Attribute> iteratorTypes,function_ref<scf::ValueVector (ValueRange,ValueRange)> bodyBuilderFn,Optional<LinalgLoopDistributionOptions> distributionOptions)314 void GenerateLoopNest<scf::ParallelOp>::doit(
315 ArrayRef<Range> loopRanges, ValueRange iterArgInitValues,
316 ArrayRef<Attribute> iteratorTypes,
317 function_ref<scf::ValueVector(ValueRange, ValueRange)> bodyBuilderFn,
318 Optional<LinalgLoopDistributionOptions> distributionOptions) {
319 assert(iterArgInitValues.empty() && "unexpected ParallelOp init values");
320 // This function may be passed more iterator types than ranges.
321 assert(iteratorTypes.size() >= loopRanges.size() &&
322 "expected iterator type for all ranges");
323 iteratorTypes = iteratorTypes.take_front(loopRanges.size());
324 SmallVector<Value, 8> lbsStorage, ubsStorage, stepsStorage, ivs;
325 unsigned numLoops = iteratorTypes.size();
326 ivs.reserve(numLoops);
327 lbsStorage.reserve(numLoops);
328 ubsStorage.reserve(numLoops);
329 stepsStorage.reserve(numLoops);
330
331 // Get the loop lb, ub, and step.
332 unpackRanges(loopRanges, lbsStorage, ubsStorage, stepsStorage);
333
334 // Modify the lb, ub, and step based on the distribution options.
335 SmallVector<DistributionMethod, 0> distributionMethod;
336 if (distributionOptions) {
337 auto &options = distributionOptions.getValue();
338 OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
339 Location loc = edsc::ScopedContext::getLocation();
340 distributionMethod.assign(distributionOptions->distributionMethod.begin(),
341 distributionOptions->distributionMethod.end());
342 SmallVector<Range, 2> parallelLoopRanges;
343 for (auto iteratorType : enumerate(iteratorTypes)) {
344 if (isParallelIteratorType(iteratorType.value()))
345 parallelLoopRanges.push_back(loopRanges[iteratorType.index()]);
346 }
347 if (distributionMethod.size() < parallelLoopRanges.size())
348 parallelLoopRanges.resize(distributionMethod.size());
349 SmallVector<ProcInfo, 2> procInfo =
350 options.procInfo(builder, loc, parallelLoopRanges);
351 unsigned index = 0;
352 for (auto iteratorType : enumerate(iteratorTypes)) {
353 if (index >= procInfo.size())
354 break;
355 if (isParallelIteratorType(iteratorType.value())) {
356 unsigned i = iteratorType.index();
357 updateBoundsForCyclicDistribution(builder, loc, procInfo[index].procId,
358 procInfo[index].nprocs, lbsStorage[i],
359 ubsStorage[i], stepsStorage[i]);
360 index++;
361 }
362 }
363 }
364 ValueRange lbs(lbsStorage), ubs(ubsStorage), steps(stepsStorage);
365 auto bodyBuilderWithoutIterArgsFn = [&](ValueRange ivs) {
366 bodyBuilderFn(ivs, {});
367 };
368 generateParallelLoopNest(lbs, ubs, steps, iteratorTypes,
369 bodyBuilderWithoutIterArgsFn, ivs,
370 distributionMethod);
371
372 assert(ivs.size() == iteratorTypes.size() && "did not generate enough loops");
373 }
374
375 } // namespace linalg
376 } // namespace mlir
377