1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion pass.
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineMap.h"
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 #include "llvm/ADT/MapVector.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Support/Debug.h"
30
31 #include <set>
32
33 #define DEBUG_TYPE "linalg-fusion"
34
35 using namespace mlir;
36 using namespace mlir::edsc;
37 using namespace mlir::edsc::intrinsics;
38 using namespace mlir::linalg;
39
40 using llvm::dbgs;
41
42 /// Implements a simple high-level fusion pass on linalg structured operations.
43 ///
44 /// In each block, linalg ops are processed in reverse textual order.
45 /// Given a linalg op `O`, fusion occurs by:
46 /// 1. inspecting the linalg ops that write into the views read by `O`. There
47 /// are 2 cases:
48 /// a) buffer case: use the SSA value of the views and a simple alias
49 /// analysis on subview ops to determine producer-consumer dependences;
50 /// b) tensor case: use SSA use-def chains on subtensor ops;
51 /// 2. greedily fuse the linalg ops that produce the subview/subtensor.
52 /// 3. inspect the fused ops and determine whether they have other remaining
53 /// LinalgOp uses. If not, then erase the original producing linalg op.
54 ///
55 /// More advanced use cases, analyses as well as profitability heuristics are
56 /// left for future work.
57
58 // Fill `offset`, `sizes` and `strides` used to iterate over the shape indexed
59 // by `permutationMap`.
inferShapeComponents(AffineMap permutationMap,ArrayRef<Range> loopRanges,SmallVectorImpl<Value> & offsets,SmallVectorImpl<Value> & sizes,SmallVectorImpl<Value> & strides)60 static void inferShapeComponents(AffineMap permutationMap,
61 ArrayRef<Range> loopRanges,
62 SmallVectorImpl<Value> &offsets,
63 SmallVectorImpl<Value> &sizes,
64 SmallVectorImpl<Value> &strides) {
65 assert(permutationMap.isProjectedPermutation() &&
66 "expected some subset of a permutation map");
67 SmallVector<Range, 4> shapeRanges(permutationMap.getNumResults());
68 unsigned idx = 0;
69 for (AffineExpr e : permutationMap.getResults()) {
70 // loopToOperandRangesMaps are permutations-only, just swap indices.
71 unsigned loopPos = e.cast<AffineDimExpr>().getPosition();
72 shapeRanges[idx++] = loopRanges[loopPos];
73 }
74 // Construct a new subshape for the tile.
75 unsigned rank = shapeRanges.size();
76 offsets.reserve(rank);
77 sizes.reserve(rank);
78 strides.reserve(rank);
79 for (auto r : shapeRanges) {
80 offsets.push_back(r.offset);
81 sizes.push_back(r.size);
82 strides.push_back(r.stride);
83 }
84 }
85
86 // Return a cloned version of `op` that operates on `loopRanges`, assumed to be
87 // a subset of the original loop ranges of `op`.
88 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps
89 // to the `loopRanges` in order to obtain view ranges.
cloneWithLoopRanges(OpBuilder & b,Location loc,LinalgOp op,ArrayRef<Range> loopRanges)90 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
91 ArrayRef<Range> loopRanges) {
92 SmallVector<Value, 8> clonedShapes;
93 clonedShapes.reserve(op.getNumShapedOperands());
94
95 // Iterate over the shape operands in order.
96 // Extract the subranges from the linearized ranges.
97 for (auto en : llvm::enumerate(op.getShapedOperands())) {
98 unsigned shapedOperandIdx = en.index();
99 AffineMap map = op.getIndexingMap(shapedOperandIdx);
100 LLVM_DEBUG(llvm::dbgs() << "shapedOperandIdx: " << shapedOperandIdx
101 << " with indexingMap: " << map << "\n");
102 SmallVector<Value, 4> offsets, sizes, strides;
103 inferShapeComponents(map, loopRanges, offsets, sizes, strides);
104 Value shape = en.value();
105 Value sub = shape.getType().isa<MemRefType>()
106 ? b.create<SubViewOp>(loc, shape, offsets, sizes, strides)
107 .getResult()
108 : b.create<SubTensorOp>(loc, shape, offsets, sizes, strides)
109 .getResult();
110 clonedShapes.push_back(sub);
111 }
112 // Append the other operands.
113 auto operands = op.getAssumedNonShapedOperands();
114 clonedShapes.append(operands.begin(), operands.end());
115
116 // Iterate over the results in order.
117 // Extract the subtensor type from the linearized range.
118 // Since we do not enforce any canonicalizations on the fly, this is always
119 // fully dynamic at construction time.
120 SmallVector<Type, 4> resultTypes;
121 resultTypes.reserve(op->getNumResults());
122 for (RankedTensorType t : op.getOutputTensorTypes()) {
123 unsigned rank = t.getRank();
124 SmallVector<int64_t, 4> staticOffsetsVector(
125 rank, ShapedType::kDynamicStrideOrOffset);
126 SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
127 SmallVector<int64_t, 4> staticStridesVector(
128 rank, ShapedType::kDynamicStrideOrOffset);
129 resultTypes.push_back(SubTensorOp::inferResultType(
130 t.cast<RankedTensorType>(), staticOffsetsVector, staticSizesVector,
131 staticStridesVector));
132 }
133
134 Operation *clonedOp = op.clone(b, loc, resultTypes, clonedShapes);
135 // When the producer is an IndexedGenericOp, we have to transform its block
136 // IV arguments according to the tiling of the consumer, i.e. offset them by
137 // the values computed in `loopRanges`.
138 if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
139 auto &block = indexedGenericOp.region().front();
140 OpBuilder::InsertionGuard g(b);
141 b.setInsertionPointToStart(&block);
142 for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
143 Value oldIndex = block.getArgument(i);
144 // TODO: replace by an affine_apply.
145 AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
146 loopRanges[i].offset);
147 oldIndex.replaceAllUsesExcept(newIndex,
148 SmallPtrSet<Operation *, 1>{newIndex});
149 }
150 }
151
152 return clonedOp;
153 }
154
155 struct ShapeDimension {
156 Value shape;
157 unsigned dimension;
158 };
159
160 // Given an `op`, returns the first (`shape`, `dimension`) pair that identifies
161 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
162 // guarantees at least one such dimension is found. If multiple candidates exist
163 // they must agree by construction (i.e. have the same size) and we just return
164 // the first one.
165 static ShapeDimension
getShapeDefiningLoopRange(LinalgOp op,unsigned loopDepth,bool fromSubViewOpOnly=false)166 getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
167 bool fromSubViewOpOnly = false) {
168 auto maps = op.indexing_maps();
169 // Iterate over the inputs and outputs in order.
170 // Extract the subranges from the linearized ranges.
171 SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
172 for (auto en : llvm::enumerate(ios)) {
173 // The method `getRangeFromOperandShape` requires using SubViewOp or
174 // SubTensorOps. If the value isnt defined from there continue.
175 // todo: The method should be adapted to get the values from
176 // `ViewInterface`. The interface needs a `getOrCreateRanges` method which
177 // currently returns a `linalg.range`. The fix here is to move this op to
178 // `std` dialect and add the method to `ViewInterface`.
179 if (fromSubViewOpOnly &&
180 !isa_and_nonnull<SubViewOp, SubTensorOp>(en.value().getDefiningOp()))
181 continue;
182
183 unsigned idx = en.index();
184 auto map = maps[idx].cast<AffineMapAttr>().getValue();
185 LLVM_DEBUG(llvm::dbgs()
186 << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
187 LLVM_DEBUG(llvm::dbgs()
188 << "getShapeDefiningLoopRange map: " << map << "\n");
189 Value shape = en.value();
190 SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
191 for (auto en2 : llvm::enumerate(map.getResults())) {
192 auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
193 if (!dimExpr)
194 continue;
195 if (loopDepth == en2.value().cast<AffineDimExpr>().getPosition()) {
196 LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
197 << loopDepth << "\n");
198 LLVM_DEBUG(llvm::dbgs()
199 << "getShapeDefiningLoopRange shape: " << shape << "\n");
200 return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
201 }
202 }
203 }
204 llvm_unreachable("Expect to be able to extract a shape defining loop range");
205 }
206
207 /// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges`
208 /// provides the loop range information for the fused loops. The rest are
209 /// obtained from the producer itself, since they are not tiled + fused.
fuse(OpBuilder & b,LinalgOp producer,const DenseMap<unsigned,Range> & fusedLoopsAndRanges)210 static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
211 const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
212
213 unsigned nPar = producer.getNumParallelLoops();
214 unsigned nRed = producer.getNumReductionLoops();
215 unsigned nWin = producer.getNumWindowLoops();
216 SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
217 for (auto fusedLoops : fusedLoopsAndRanges)
218 loopRanges[fusedLoops.first] = fusedLoops.second;
219
220 // Iterate over all dimensions. For the dimensions not identified by the
221 // producer map for `producerIdx`, we need to explicitly compute the shape
222 // that defines the loop ranges using the `producer`.
223 for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
224 if (loopRanges[i].offset)
225 LLVM_DEBUG(llvm::dbgs()
226 << "existing LoopRange: " << loopRanges[i] << "\n");
227 else {
228 auto shapeDim = getShapeDefiningLoopRange(producer, i);
229 loopRanges[i] = Range{std_constant_index(0),
230 std_dim(shapeDim.shape, shapeDim.dimension),
231 std_constant_index(1)};
232 LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
233 }
234 }
235
236 return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges);
237 }
238
239 /// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is
240 /// expected to be defined by a subview op or a subtensor op.
getRangeFromOperandShape(OpBuilder & b,Location loc,Value shapedOperand,unsigned dim)241 static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
242 Value shapedOperand, unsigned dim) {
243 Operation *shapeProducingOp = shapedOperand.getDefiningOp();
244 if (auto subViewOp = dyn_cast<SubViewOp>(shapeProducingOp))
245 return subViewOp.getOrCreateRanges(b, loc)[dim];
246 if (auto subTensorOp = dyn_cast<SubTensorOp>(shapeProducingOp))
247 return subTensorOp.getOrCreateRanges(b, loc)[dim];
248 llvm_unreachable("SubviewOp or SubTensorOp expected");
249 }
250
251 /// Fuses the producer of `producerIdx` into the loop immediately enclosing
252 /// `consumer`. This is achieved by "recomputing" the `producer` at the time it
253 /// is needed just before the `consumer.
254 ///
255 /// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
256 /// 2 cases:
257 /// 1. Buffer case: `producerIdx` is the index of the buffer in
258 /// `producer.getOutputBuffers()`.
259 /// 2. Tensor case: `producerIdx` is the index of the tensor in
260 /// `producer.getResults()`.
fuse(OpBuilder & b,LinalgOp producer,unsigned producerIdx,LinalgOp consumer,unsigned consumerIdx)261 static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
262 LinalgOp consumer, unsigned consumerIdx) {
263 AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
264 LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
265 << ", producer map: " << producerMap << "\n");
266 DenseMap<unsigned, Range> fusedLoopsAndRanges;
267 Location loc = consumer.getLoc();
268 Value shapedOperand = consumer.getShapedOperand(consumerIdx);
269 for (auto en : llvm::enumerate(producerMap.getResults())) {
270 unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
271 fusedLoopsAndRanges[posInProducerLoop] =
272 getRangeFromOperandShape(b, loc, shapedOperand, en.index());
273 }
274 return fuse(b, producer, fusedLoopsAndRanges);
275 }
276
277 // Encode structural fusion safety preconditions.
278 // Some of these will be lifted in the future with better analysis.
isStructurallyFusableProducer(LinalgOp producer,Value consumedView,LinalgOp consumer)279 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
280 LinalgOp consumer) {
281 assert(producer.hasBufferSemantics() &&
282 "expected linalg op with buffer semantics");
283 assert(consumer.hasBufferSemantics() &&
284 "expected linalg op with buffer semantics");
285 if (producer.getNumOutputs() != 1) {
286 LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)");
287 return false;
288 }
289 // Only fuse when the producer block dominates.
290 DominanceInfo dom(producer.getOperation());
291 if (!dom.dominates(producer->getBlock(), consumer->getBlock())) {
292 LLVM_DEBUG(
293 llvm::dbgs()
294 << "\nNot structurally fusable (producer block does not dominate)");
295 return false;
296 }
297 return true;
298 }
299
isProducerLastWriteOfView(const LinalgDependenceGraph & graph,LinalgOp consumer,Value consumedView,LinalgOp producer)300 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
301 LinalgOp consumer,
302 Value consumedView,
303 LinalgOp producer) {
304 assert(producer.hasBufferSemantics() &&
305 "expected linalg op with buffer semantics");
306 assert(consumer.hasBufferSemantics() &&
307 "expected linalg op with buffer semantics");
308 // Make some simple structural checks that alleviate the need for more
309 // complex analyses.
310 if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
311 LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t"
312 << *producer.getOperation());
313 return false;
314 }
315 // Check for any interleaved write to consumedView.
316 if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
317 LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t"
318 << *producer.getOperation());
319 return false;
320 }
321 return true;
322 }
323
isFusableInto(const LinalgDependenceGraph & graph,LinalgOp consumer,Value consumedView,LinalgOp producer)324 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
325 LinalgOp consumer, Value consumedView,
326 LinalgOp producer) {
327 assert(producer.hasBufferSemantics() &&
328 "expected linalg op with buffer semantics");
329 assert(consumer.hasBufferSemantics() &&
330 "expected linalg op with buffer semantics");
331 if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
332 return false;
333 // Check for any fusion-preventing dependence to any shape read/written that
334 // would violate dependences.
335 if (!graph.findCoveringDependences(producer, consumer).empty()) {
336 LLVM_DEBUG(llvm::dbgs()
337 << "\n***Not fusable due to an interleaved dependence:\t"
338 << *producer.getOperation());
339 return false;
340 }
341 if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
342 // TODO: add a level of indirection to linalg.generic.
343 if (convOp.padding())
344 return false;
345 }
346 if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
347 // TODO: add a level of indirection to linalg.generic.
348 if (convOp.padding())
349 return false;
350 }
351 return true;
352 }
353
isSameSubView(Value a,Value b)354 static bool isSameSubView(Value a, Value b) {
355 if (a == b)
356 return true;
357 auto sva = a.getDefiningOp<SubViewOp>();
358 auto svb = b.getDefiningOp<SubViewOp>();
359 if (!sva || !svb)
360 return false;
361 if (!isSameSubView(sva.getViewSource(), svb.getViewSource()))
362 return false;
363 if (sva.getType() != svb.getType())
364 return false;
365 if (sva.getNumOperands() != svb.getNumOperands())
366 return false;
367 if (sva.static_offsets() != svb.static_offsets())
368 return false;
369 if (sva.static_sizes() != svb.static_sizes())
370 return false;
371 if (sva.static_strides() != svb.static_strides())
372 return false;
373 /// Skip the "source" operand.
374 for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx)
375 if (sva.getOperand(idx) != svb.getOperand(idx))
376 return false;
377 return true;
378 }
379
380 static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
findFusableProducer(LinalgOp consumer,unsigned consumerIdx,const LinalgDependenceGraph & dependenceGraph)381 findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
382 const LinalgDependenceGraph &dependenceGraph) {
383 // Only consider RAW and WAW atm.
384 for (auto depType : {
385 LinalgDependenceGraph::DependenceType::RAW,
386 LinalgDependenceGraph::DependenceType::WAW,
387 }) {
388 for (auto dependence : llvm::make_filter_range(
389 dependenceGraph.getDependencesInto(consumer, depType),
390 [consumerIdx](
391 LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
392 return elem.indexingOpView.operandIndex == consumerIdx;
393 })) {
394 auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
395
396 // Check that the dependence is indeed on the input `consumerIdx` view.
397 auto consumedView =
398 consumer.getBuffer(dependence.indexingOpView.operandIndex);
399 if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
400 continue;
401
402 // Consumer consumes this view, `isStructurallyFusableProducer` also
403 // checks whether it is a strict subview of the producer view.
404 auto producedView =
405 producer.getBuffer(dependence.dependentOpView.operandIndex);
406 LLVM_DEBUG(llvm::dbgs()
407 << "\n"
408 << LinalgDependenceGraph::getDependenceTypeStr(depType)
409 << "producer: " << *producer.getOperation()
410 << " view: " << producedView << " output index: "
411 << dependence.dependentOpView.operandIndex -
412 producer.getNumInputs()
413 << "\n");
414 (void)producedView;
415
416 // Simple fusability checks.
417 if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
418 continue;
419
420 return dependence;
421 }
422 }
423 return {};
424 }
425
426 Optional<FusionInfo>
fuseProducerOfBuffer(OpBuilder & b,LinalgOp consumer,unsigned consumerIdx,const LinalgDependenceGraph & graph)427 mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
428 unsigned consumerIdx,
429 const LinalgDependenceGraph &graph) {
430 Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
431 findFusableProducer(consumer, consumerIdx, graph);
432 if (!fusableDependence)
433 return {};
434
435 LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
436 // If producer is already in the same block as consumer, we are done.
437 if (consumer->getBlock() == producerOp->getBlock())
438 return {};
439
440 unsigned producerIdx = fusableDependence->dependentOpView.operandIndex -
441 producerOp.getNumInputs();
442 Value consumerView = consumer.getShapedOperand(consumerIdx);
443
444 // Must be a subview or a slice to guarantee there are loops we can fuse
445 // into.
446 auto subView = consumerView.getDefiningOp<SubViewOp>();
447 auto slice = consumerView.getDefiningOp<SliceOp>();
448 if (!subView && !slice) {
449 LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview or slice)");
450 return {};
451 }
452
453 // Fuse `producer` just before `consumer`.
454 OpBuilder::InsertionGuard g(b);
455 b.setInsertionPoint(consumer.getOperation());
456 ScopedContext scope(b, consumer.getLoc());
457 LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
458
459 auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx);
460 return FusionInfo{producerOp, fusedProducer};
461 }
462
463 /// Walk back use-def chain through scf::For yields.
464 /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp
getProducerOfTensor(Value tensor,LinalgOp & producer,unsigned & outputIndex)465 static void getProducerOfTensor(Value tensor, LinalgOp &producer,
466 unsigned &outputIndex) {
467 if (!tensor.getType().isa<RankedTensorType>())
468 return;
469
470 while (true) {
471 if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
472 producer = linalgOp;
473 outputIndex = tensor.cast<OpResult>().getResultNumber();
474 return;
475 }
476 if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) {
477 tensor = subTensorOp.source();
478 continue;
479 }
480 if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
481 if (auto forOp = blockArg.getDefiningOp<scf::ForOp>()) {
482 tensor = forOp.getResult(blockArg.getArgNumber());
483 continue;
484 }
485 }
486 return;
487 }
488 }
489
fuseProducerOfTensor(OpBuilder & b,LinalgOp consumer,unsigned consumerIdx)490 Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
491 LinalgOp consumer,
492 unsigned consumerIdx) {
493 Value inputTensor = consumer.getInput(consumerIdx);
494 LinalgOp producerOp;
495 unsigned producerIdx;
496 getProducerOfTensor(inputTensor, producerOp, producerIdx);
497
498 // Must be a subtensor to guarantee there are loops we can fuse into.
499 auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
500 if (!subTensor || !producerOp) {
501 LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subtensor)");
502 return {};
503 }
504
505 // If producer is already in the same block as consumer, we are done.
506 if (consumer->getBlock() == producerOp->getBlock())
507 return {};
508
509 // Insert fused `producer` just before `consumer`.
510 OpBuilder::InsertionGuard g(b);
511 b.setInsertionPoint(consumer.getOperation());
512 ScopedContext scope(b, consumer.getLoc());
513 LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
514 LinalgOp fusedProducer =
515 fuse(b, producerOp, producerIdx, consumer, consumerIdx);
516
517 // Replace use.
518 // Canonicalizations are not guaranteed to have happened before constructing
519 // `fusedProducer`. In the tensor case this can result in temporary type
520 // mismatches. Insert a `tensor_cast` op to propagate the transformation
521 // invariant that types are compatible.
522 Value def = fusedProducer->getResult(producerIdx);
523 OpOperand &use = consumer->getOpOperand(consumerIdx);
524 Type consumerType = use.get().getType();
525 if (consumerType != def.getType())
526 def = b.create<TensorCastOp>(fusedProducer.getLoc(), consumerType, def);
527 use.set(def);
528 return FusionInfo{producerOp, fusedProducer};
529 }
530
531 /// Prune all dimensions that are of reduction iterator type from `map`.
pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,AffineMap map)532 static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
533 AffineMap map) {
534 SmallVector<unsigned, 2> projectedDims;
535 for (auto attr : llvm::enumerate(iteratorTypes)) {
536 if (!isParallelIterator(attr.value()))
537 projectedDims.push_back(attr.index());
538 }
539 return getProjectedMap(map, projectedDims);
540 }
541
542 /// Returns the mapping from iterations in the consumer that write to the same
543 /// location as the iterations in the producer. To do so use
544 /// - indexing map of the fused view in the consumer : consumerIndexMap
545 /// - indexing map of the fused view in the producer : producerIndexMap
546 /// consumerLoopToProducerLoop =
547 /// inverse(producerIndexMap).compose(consumerIndexMap)
getConsumerLoopToProducerLoopMap(LinalgDependenceGraph::LinalgDependenceGraphElem dependence)548 static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
549 LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
550 auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
551 AffineMap producerIndexingMap =
552 producer.getIndexingMap(dependence.dependentOpView.operandIndex);
553 auto consumer = cast<LinalgOp>(dependence.indexingOpView.op);
554 AffineMap consumerIndexingMap =
555 consumer.getIndexingMap(dependence.indexingOpView.operandIndex);
556
557 AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
558 producer.iterator_types().getValue(), producerIndexingMap);
559 if (!prunedProducerIndexingMap.isPermutation())
560 return None;
561
562 if (consumerIndexingMap.getNumResults() !=
563 prunedProducerIndexingMap.getNumResults())
564 return None;
565
566 LLVM_DEBUG({
567 llvm::dbgs() << "\t producerMap : ";
568 producerIndexingMap.print(llvm::dbgs());
569 llvm::dbgs() << " pruned : ";
570 prunedProducerIndexingMap.print(llvm::dbgs());
571 llvm::dbgs() << "\n";
572 llvm::dbgs() << "\t consumerMap : ";
573 consumerIndexingMap.print(llvm::dbgs());
574 llvm::dbgs() << "\n";
575 });
576
577 AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
578 if (!invProducerIndexMap)
579 return None;
580
581 return invProducerIndexMap.compose(consumerIndexingMap);
582 }
583
584 /// Given a projected permutation `map`, returns true if the map changes the
585 /// order in which the fused loop dimension appear.
doesTransposeAccess(AffineMap map,const std::set<unsigned> & fusableLoops)586 static bool doesTransposeAccess(AffineMap map,
587 const std::set<unsigned> &fusableLoops) {
588 Optional<unsigned> lastFusableLoop;
589 for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) {
590 return expr.cast<AffineDimExpr>().getPosition();
591 })) {
592 if (!fusableLoops.count(pos))
593 continue;
594 if (!lastFusableLoop) {
595 lastFusableLoop = pos;
596 continue;
597 }
598 if (pos <= lastFusableLoop.getValue())
599 return true;
600 lastFusableLoop = pos;
601 }
602 return false;
603 }
604
605 /// Returns the positions of the loop in `op` that can be tiled based on the
606 /// operations that are to be fused with it. For example, in a
607 ///
608 /// linalg.matmul ins(%a, %b : ...) outs(%c : ...)
609 ///
610 /// if the producer of %a needs to be fused with this op, only the `i` loop of
611 /// the matmul can be tiled while fusing. If producer of %a, and %b are to be
612 /// fused, then no loops can be tiled while fusing. The conditions used are:
613 /// 1. Only parallel loops can be used for tile + fuse. Find the number of
614 /// common outer parallel loops between the op and its producers being fused.
615 /// 2. Of the parallel loops only some can be fused. Only those loops can be
616 /// fused such where the fusable loops iteration space only touches one tile
617 /// of the fused operation. This is because the producer (which is writing
618 /// the fused subview) has update semantics.
619 ///
620 /// Since an inverse computation is needed, we need to consider the projection
621 /// of the producerIndexMap w.r.t the parallel loops. The actual fusable loops
622 /// are the dimensions of the consumerLoopToProducerLoop map that correspond to
623 /// parallel loops and appear in the result of the map
624 ///
625 /// Example 1:
626 /// linalg.fill(%c, %cst)
627 /// linalg.matmul ins(%a, %b) outs(%c)
628 /// Number of parallel loops : 2
629 /// producerIndexMap = affine_map<(i, j) ->(i , j)>
630 /// consumerIndexMap = affine_map<(i, j, k) -> (i, j)>
631 /// consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)>
632 /// Fused dimensions : i, j
633 ///
634 /// Example 2:
635 /// linalg.matmul ins(%a, %b) outs(%c)
636 /// linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ...
637 /// iterator_types = ["parallel", "parallel"]}
638 /// ins(%c) ...
639 ///
640 /// Number of parallel loops = 2:
641 /// producerIndexMap (projected to parallel loops) =
642 /// affine_map<(i, j) -> (i, j)>
643 /// consumerLoopToProducerLoop2 = affine_map<(i, j) -> (j, i)>
644 /// Fused dimensions : i, j
645 ///
646 /// Example 3:
647 /// linalg.copy(%s, %b)
648 /// linalg.matmul ins(%a, %b) outs(%c)
649 ///
650 /// Number of parallel loops = 2
651 /// produceIndexMap : affine_map<(i, j) -> (i, j)>
652 /// consumerLoopToProduceLoops = affine_map<(i, j, k) -> (k, j)>
653 /// submap with only parallel loops = affine_map<(i, j) -> (j)>
654 /// Fused dimensions : j
655 static std::set<unsigned>
collectFusableLoops(ArrayRef<LinalgOp> ops,const FusableOpDependencesTy & fusableDependences)656 collectFusableLoops(ArrayRef<LinalgOp> ops,
657 const FusableOpDependencesTy &fusableDependences) {
658 assert(!ops.empty());
659 auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
660 return linalgOp.iterator_types()
661 .getValue()
662 .take_while([](Attribute attr) -> bool {
663 return attr.cast<StringAttr>().getValue() ==
664 getParallelIteratorTypeName();
665 })
666 .size();
667 };
668
669 size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back());
670 for (auto op : ops.drop_back()) {
671 numOuterParallelLoops =
672 std::min(numOuterParallelLoops, getNumOuterParallelLoops(op));
673 }
674
675 std::set<unsigned> fusableLoops;
676 auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
677 fusableLoops.insert(range.begin(), range.end());
678
679 for (auto op : reverse(ops)) {
680 for (auto dependence : fusableDependences.lookup(op)) {
681 LLVM_DEBUG({
682 llvm::dbgs() << "\t fusable :";
683 for (unsigned i : fusableLoops)
684 llvm::dbgs() << " " << i;
685 llvm::dbgs() << "\n";
686 });
687
688 Optional<AffineMap> consumerLoopToProducerLoop =
689 getConsumerLoopToProducerLoopMap(dependence);
690 if (!consumerLoopToProducerLoop) {
691 op.emitRemark("failed to get map from consumer loop to producer loop");
692 return {};
693 }
694 // todo: This condition is only an implementation limitation. When fusing
695 // the operation, if the accesses in the producer/consumer are transposes
696 // of each other, the loop bounds for the tiled producer can be
697 // manipulated accordingly. This requires some additional bookkeeping in
698 // the implementation of tile+fuse that is defered to later.
699 if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) {
700 op.emitRemark("unhandled fusion when fusion requires permutation");
701 return {};
702 }
703
704 std::set<unsigned> candidates;
705 for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) {
706 unsigned position = expr.cast<AffineDimExpr>().getPosition();
707 if (fusableLoops.count(position))
708 candidates.insert(position);
709 }
710 LLVM_DEBUG({
711 llvm::dbgs() << "\t candidates :";
712 for (unsigned i : candidates)
713 llvm::dbgs() << " " << i;
714 llvm::dbgs() << "\n";
715 });
716 if (candidates.empty())
717 return {};
718 std::swap(candidates, fusableLoops);
719 }
720 }
721
722 return fusableLoops;
723 }
724
725 /// Find all dependences that are fusable.
findAllFusableDependences(ArrayRef<LinalgOp> ops,const LinalgDependenceGraph & dependenceGraph)726 FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
727 ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) {
728 FusableOpDependencesTy fusableDependences;
729 // TODO: Currently fusion would not be legal if the fusable dependence is to
730 // the same producer but different indexing map in the consumer. Fix this, but
731 // in the meanwhile disallow such a fusion.
732 DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
733 for (LinalgOp op : reverse(ops)) {
734 for (auto operandIndex :
735 llvm::seq<unsigned>(0, op.getNumInputsAndOutputBuffers())) {
736 Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
737 fusableDependence =
738 findFusableProducer(op, operandIndex, dependenceGraph);
739 if (!fusableDependence)
740 continue;
741 LinalgOp producerOp =
742 cast<LinalgOp>(fusableDependence->dependentOpView.op);
743 // Do not fuse dependences that are to operations not in the same basic
744 // block. This avoid moving fused operations across loops that might
745 // themselves carry dependency making the fusion illegal.
746 if (producerOp->getBlock() != op->getBlock()) {
747 op.emitRemark("unhandled fusion of ops in different basic blocks");
748 return FusableOpDependencesTy{};
749 }
750 // Make sure that the indexing map of the view used for fusion in the
751 // producer is a projected permutation.
752 unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
753 AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
754 if (!producerMap.isProjectedPermutation()) {
755 op.emitRemark(
756 "unhandled non permutation indexing map for fused view in "
757 "producer for operand at index ")
758 << operandIndex;
759 return FusableOpDependencesTy{};
760 }
761
762 unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
763 AffineMap consumerMap = op.getIndexingMap(consumerIdx);
764 if (!consumerMap.isProjectedPermutation()) {
765 op.emitRemark(
766 "unhandled case where indexing map for fused view in the consumer "
767 "is "
768 "not a projected permuration while fusing at index ")
769 << operandIndex;
770 return FusableOpDependencesTy{};
771 }
772
773 // Check if the producer is already a fusion candidate. Cannot fuse this
774 // dependence if it has a different indexing map when used in the
775 // consumer.
776 if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
777 fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
778 op.emitRemark(
779 "unhandled fusion to the same producer but with different "
780 "indexing maps");
781 return FusableOpDependencesTy{};
782 }
783 fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
784
785 fusableDependences[producerOp.getOperation()].push_back(
786 *fusableDependence);
787 }
788 }
789 return fusableDependences;
790 }
791
792 /// Tile the fused loops in the root operation, by setting the tile sizes for
793 /// all other loops to zero (those will be tiled later).
tileRootOperation(OpBuilder & builder,LinalgOp op,ArrayRef<Value> tileSizeVector,const LinalgTilingOptions & options,const std::set<unsigned> & fusedLoops)794 static Optional<TiledLinalgOp> tileRootOperation(
795 OpBuilder &builder, LinalgOp op, ArrayRef<Value> tileSizeVector,
796 const LinalgTilingOptions &options, const std::set<unsigned> &fusedLoops) {
797 SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end());
798 auto zero = std_constant_index(0);
799 for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
800 if (!fusedLoops.count(i))
801 tileSizes[i] = zero;
802 LinalgTilingOptions tileFusedLoopsOptions = options;
803 tileFusedLoopsOptions.setTileSizes(tileSizes);
804 return tileLinalgOp(builder, op, tileFusedLoopsOptions);
805 }
806
807 /// Fuse the operations in `fusionCandidates` with `tiledOp`. Latter is expected
808 /// to be a tiled operation such that it is valid to fuse all operations in
809 /// `fusionCandidates`, i.e. move the operation within the inter-tile loops of
810 /// `tiledOp`.
811 static SmallVector<LinalgOp, 1>
fuseOperations(OpBuilder & builder,LinalgOp tiledOp,ArrayRef<LinalgOp> fusionCandidates,const FusableOpDependencesTy & fusableDependences,const std::set<unsigned> & fusedLoops)812 fuseOperations(OpBuilder &builder, LinalgOp tiledOp,
813 ArrayRef<LinalgOp> fusionCandidates,
814 const FusableOpDependencesTy &fusableDependences,
815 const std::set<unsigned> &fusedLoops) {
816 OpBuilder::InsertionGuard guard(builder);
817 builder.setInsertionPoint(tiledOp);
818 DenseMap<unsigned, Range> fusedLoopsAndRanges;
819 for (unsigned loop : fusedLoops) {
820 ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true);
821 fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
822 builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
823 }
824
825 SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
826 for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) {
827 LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges);
828 fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
829 builder.setInsertionPoint(fusedOp);
830 }
831 return fusedOps;
832 }
833
834 template <typename LoopType>
835 static Optional<TiledAndFusedLinalgOps>
tileAndFuseLinalgOpsImpl(OpBuilder & builder,ArrayRef<LinalgOp> ops,const LinalgDependenceGraph & dependenceGraph,const LinalgTilingOptions & tilingOptions)836 tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
837 const LinalgDependenceGraph &dependenceGraph,
838 const LinalgTilingOptions &tilingOptions) {
839 if (ops.empty())
840 return llvm::None;
841 LinalgOp rootOp = ops.back();
842 for (auto op : enumerate(ops)) {
843 // TODO: Nothing in the fusion of sequence of ops is specific to
844 // buffers. This check can be removed after it is tested on tensors.
845 LinalgOp linalgOp = op.value();
846 if (!linalgOp.hasBufferSemantics()) {
847 linalgOp.emitError("tile and fuse only tested for buffer operation");
848 return llvm::None;
849 }
850 }
851 // TODO: Support interchange with tile + fuse. This might actually help do
852 // better fusion.
853 if (!tilingOptions.interchangeVector.empty()) {
854 rootOp.emitError("unable to handle tile and fuse with interchange");
855 return llvm::None;
856 }
857
858 OpBuilder::InsertionGuard guard(builder);
859 builder.setInsertionPoint(rootOp);
860 ScopedContext scope(builder, rootOp.getLoc());
861
862 // Find all the producers.
863 FusableOpDependencesTy fusableDependences =
864 findAllFusableDependences(ops, dependenceGraph);
865 if (fusableDependences.empty())
866 return llvm::None;
867
868 TiledAndFusedLinalgOps ret;
869 // Find the loops that can be tiled and fused.
870 ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);
871
872 // If there are no fusable dependences or there are no tile+fusable loops,
873 // just return.
874 if (ret.fusedLoopDims.empty()) {
875 return llvm::None;
876 }
877
878 // Tile the fused loops in the last operation in the list.
879 SmallVector<Value, 4> tileSizeVector =
880 tilingOptions.tileSizeComputationFunction(builder, rootOp);
881 Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
882 builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
883 if (!tiledRootOp) {
884 rootOp.emitError("failed to tile the fused loops");
885 return llvm::None;
886 }
887 ret.op = tiledRootOp->op;
888 ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
889
890 // Fuse the other operations into the fused inter-tile loops produced above.
891 ret.fusedProducers = fuseOperations(builder, ret.op, ops.drop_back(),
892 fusableDependences, ret.fusedLoopDims);
893 return ret;
894 }
895
896 Optional<TiledAndFusedLinalgOps>
tileAndFuseLinalgOps(OpBuilder & builder,ArrayRef<LinalgOp> ops,const LinalgDependenceGraph & dependenceGraph,const LinalgTilingOptions & tilingOptions)897 mlir::linalg::tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
898 const LinalgDependenceGraph &dependenceGraph,
899 const LinalgTilingOptions &tilingOptions) {
900 switch (tilingOptions.loopType) {
901 case LinalgTilingLoopType::Loops:
902 return tileAndFuseLinalgOpsImpl<scf::ForOp>(builder, ops, dependenceGraph,
903 tilingOptions);
904 case LinalgTilingLoopType::ParallelLoops:
905 return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
906 builder, ops, dependenceGraph, tilingOptions);
907 default:;
908 }
909 return llvm::None;
910 }
911