1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
14 #include "mlir/Dialect/Affine/IR/AffineOps.h"
15 #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h"
16 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
17 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
18 #include "mlir/Dialect/Linalg/Passes.h"
19 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
20 #include "mlir/Dialect/Linalg/Utils/Utils.h"
21 #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h"
22 #include "mlir/IR/AffineExpr.h"
23 #include "mlir/IR/AffineMap.h"
24 #include "mlir/IR/Dominance.h"
25 #include "mlir/Support/LLVM.h"
26 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27 #include "llvm/ADT/MapVector.h"
28 #include "llvm/Support/CommandLine.h"
29 #include "llvm/Support/Debug.h"
30 
31 #include <set>
32 
33 #define DEBUG_TYPE "linalg-fusion"
34 
35 using namespace mlir;
36 using namespace mlir::edsc;
37 using namespace mlir::edsc::intrinsics;
38 using namespace mlir::linalg;
39 
40 using llvm::dbgs;
41 
42 /// Implements a simple high-level fusion pass on linalg structured operations.
43 ///
44 /// In each block, linalg ops are processed in reverse textual order.
45 /// Given a linalg op `O`, fusion occurs by:
46 ///   1. inspecting the linalg ops that write into the views read by `O`. There
47 ///      are 2 cases:
48 ///      a) buffer case: use the SSA value of the views and a simple alias
49 ///         analysis on subview ops to determine producer-consumer dependences;
50 ///      b) tensor case: use SSA use-def chains on subtensor ops;
51 ///   2. greedily fusing the linalg ops that produce the subview/subtensor;
52 ///   3. inspecting the fused ops and, if they have no other remaining LinalgOp
53 ///      uses, erasing the original producing linalg op.
54 ///
55 /// More advanced use cases and analyses, as well as profitability heuristics,
56 /// are left for future work.
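///
/// Illustrative sketch (schematic IR only; the names, shapes and exact op
/// syntax below are hypothetical and abbreviated, not taken from a test).
/// Given a tiled consumer that reads a view written by a producer through a
/// subview:
///
///   linalg.fill(%A, %cst)
///   scf.for %i ... {
///     %svA = subview %A[%i, 0] [%ts, %N] [1, 1]
///     linalg.matmul ins(%svA, %B ...) outs(%C ...)
///   }
///
/// the pass clones the producer onto just the consumed tile, immediately
/// before its use, and later erases the original producer if it has no
/// remaining LinalgOp uses:
///
///   scf.for %i ... {
///     %svA = subview %A[%i, 0] [%ts, %N] [1, 1]
///     linalg.fill(%svA, %cst)
///     linalg.matmul ins(%svA, %B ...) outs(%C ...)
///   }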
57 
58 // Fill `offsets`, `sizes` and `strides` used to iterate over the shape indexed
59 // by `permutationMap`.
60 static void inferShapeComponents(AffineMap permutationMap,
61                                  ArrayRef<Range> loopRanges,
62                                  SmallVectorImpl<Value> &offsets,
63                                  SmallVectorImpl<Value> &sizes,
64                                  SmallVectorImpl<Value> &strides) {
65   assert(permutationMap.isProjectedPermutation() &&
66          "expected some subset of a permutation map");
67   SmallVector<Range, 4> shapeRanges(permutationMap.getNumResults());
68   unsigned idx = 0;
69   for (AffineExpr e : permutationMap.getResults()) {
70     // loopToOperandRangesMaps are permutations-only, just swap indices.
71     unsigned loopPos = e.cast<AffineDimExpr>().getPosition();
72     shapeRanges[idx++] = loopRanges[loopPos];
73   }
74   // Construct a new subshape for the tile.
75   unsigned rank = shapeRanges.size();
76   offsets.reserve(rank);
77   sizes.reserve(rank);
78   strides.reserve(rank);
79   for (auto r : shapeRanges) {
80     offsets.push_back(r.offset);
81     sizes.push_back(r.size);
82     strides.push_back(r.stride);
83   }
84 }
85 
86 // Return a cloned version of `op` that operates on `loopRanges`, assumed to be
87 // a subset of the original loop ranges of `op`.
88 // This is achieved by applying the `loopToOperandRangesMaps` permutation maps
89 // to the `loopRanges` in order to obtain view ranges.
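//
// Illustrative example (hypothetical map and ranges, not tied to a specific
// test): with an indexing map affine_map<(i, j, k) -> (k, j)> and
// loopRanges = [rI, rJ, rK], the cloned operand is a subview/subtensor built
// from the subranges [rK, rJ], i.e. its offsets/sizes/strides are taken from
// rK and rJ, in that order.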
90 static LinalgOp cloneWithLoopRanges(OpBuilder &b, Location loc, LinalgOp op,
91                                     ArrayRef<Range> loopRanges) {
92   SmallVector<Value, 8> clonedShapes;
93   clonedShapes.reserve(op.getNumShapedOperands());
94 
95   // Iterate over the shape operands in order.
96   // Extract the subranges from the linearized ranges.
97   for (auto en : llvm::enumerate(op.getShapedOperands())) {
98     unsigned shapedOperandIdx = en.index();
99     AffineMap map = op.getIndexingMap(shapedOperandIdx);
100     LLVM_DEBUG(llvm::dbgs() << "shapedOperandIdx: " << shapedOperandIdx
101                             << " with indexingMap: " << map << "\n");
102     SmallVector<Value, 4> offsets, sizes, strides;
103     inferShapeComponents(map, loopRanges, offsets, sizes, strides);
104     Value shape = en.value();
105     Value sub = shape.getType().isa<MemRefType>()
106                     ? b.create<SubViewOp>(loc, shape, offsets, sizes, strides)
107                           .getResult()
108                     : b.create<SubTensorOp>(loc, shape, offsets, sizes, strides)
109                           .getResult();
110     clonedShapes.push_back(sub);
111   }
112   // Append the other operands.
113   auto operands = op.getAssumedNonShapedOperands();
114   clonedShapes.append(operands.begin(), operands.end());
115 
116   // Iterate over the results in order.
117   // Extract the subtensor type from the linearized range.
118   // Since we do not enforce any canonicalizations on the fly, this is always
119   // fully dynamic at construction time.
120   SmallVector<Type, 4> resultTypes;
121   resultTypes.reserve(op->getNumResults());
122   for (RankedTensorType t : op.getOutputTensorTypes()) {
123     unsigned rank = t.getRank();
124     SmallVector<int64_t, 4> staticOffsetsVector(
125         rank, ShapedType::kDynamicStrideOrOffset);
126     SmallVector<int64_t, 4> staticSizesVector(rank, ShapedType::kDynamicSize);
127     SmallVector<int64_t, 4> staticStridesVector(
128         rank, ShapedType::kDynamicStrideOrOffset);
129     resultTypes.push_back(SubTensorOp::inferResultType(
130         t, staticOffsetsVector, staticSizesVector,
131         staticStridesVector));
132   }
133 
134   Operation *clonedOp = op.clone(b, loc, resultTypes, clonedShapes);
135   // When the producer is an IndexedGenericOp, we have to transform its block
136   // IV arguments according to the tiling of the consumer, i.e. offset them by
137   // the values computed in `loopRanges`.
138   if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(clonedOp)) {
139     auto &block = indexedGenericOp.region().front();
140     OpBuilder::InsertionGuard g(b);
141     b.setInsertionPointToStart(&block);
142     for (unsigned i = 0, e = indexedGenericOp.getNumLoops(); i < e; ++i) {
143       Value oldIndex = block.getArgument(i);
144       // TODO: replace by an affine_apply.
145       AddIOp newIndex = b.create<AddIOp>(indexedGenericOp.getLoc(), oldIndex,
146                                          loopRanges[i].offset);
147       oldIndex.replaceAllUsesExcept(newIndex,
148                                     SmallPtrSet<Operation *, 1>{newIndex});
149     }
150   }
151 
152   return clonedOp;
153 }
154 
155 struct ShapeDimension {
156   Value shape;
157   unsigned dimension;
158 };
159 
160 // Given an `op`, returns the first (`shape`, `dimension`) pair that identifies
161 // the loop range at `loopDepth`. The semantics of the loopToOperandRangesMaps
162 /// guarantees at least one such dimension is found. If multiple candidates
163 /// exist, they must agree by construction (i.e. have the same size), and we
164 /// just return the first one.
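//
// Illustrative example (using the usual matmul indexing maps as an
// assumption): for maps (i, j, k) -> (i, k), (i, j, k) -> (k, j) and
// (i, j, k) -> (i, j), loopDepth = 2 (the `k` loop) is defined by dimension 1
// of the first operand; that first match is the pair returned.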
165 static ShapeDimension
166 getShapeDefiningLoopRange(LinalgOp op, unsigned loopDepth,
167                           bool fromSubViewOpOnly = false) {
168   auto maps = op.indexing_maps();
169   // Iterate over the inputs and outputs in order.
170   // Extract the subranges from the linearized ranges.
171   SmallVector<Value, 8> ios(op.getInputsAndOutputBuffers());
172   for (auto en : llvm::enumerate(ios)) {
173     // The method `getRangeFromOperandShape` requires using SubViewOp or
174     // SubTensorOp. If the value isn't defined by one of those ops, continue.
175     // TODO: The method should be adapted to get the values from
176     // `ViewInterface`. The interface needs a `getOrCreateRanges` method which
177     // currently returns a `linalg.range`. The fix here is to move this op to
178     // the `std` dialect and add the method to `ViewInterface`.
179     if (fromSubViewOpOnly &&
180         !isa_and_nonnull<SubViewOp, SubTensorOp>(en.value().getDefiningOp()))
181       continue;
182 
183     unsigned idx = en.index();
184     auto map = maps[idx].cast<AffineMapAttr>().getValue();
185     LLVM_DEBUG(llvm::dbgs()
186                << "getShapeDefiningLoopRange I/O idx: " << idx << "\n");
187     LLVM_DEBUG(llvm::dbgs()
188                << "getShapeDefiningLoopRange map: " << map << "\n");
189     Value shape = en.value();
190     SmallVector<Value, 8> shapeRanges(map.getNumResults(), nullptr);
191     for (auto en2 : llvm::enumerate(map.getResults())) {
192       auto dimExpr = en2.value().dyn_cast<AffineDimExpr>();
193       if (!dimExpr)
194         continue;
195       if (loopDepth == dimExpr.getPosition()) {
196         LLVM_DEBUG(llvm::dbgs() << "getShapeDefiningLoopRange loopDepth: "
197                                 << loopDepth << "\n");
198         LLVM_DEBUG(llvm::dbgs()
199                    << "getShapeDefiningLoopRange shape: " << shape << "\n");
200         return ShapeDimension{shape, static_cast<unsigned>(en2.index())};
201       }
202     }
203   }
204   llvm_unreachable("Expect to be able to extract a shape defining loop range");
205 }
206 
207 /// Fuse the producer by cloning the `producer`. The `fusedLoopsAndRanges`
208 /// provides the loop range information for the fused loops. The rest are
209 /// obtained from the producer itself, since they are not tiled + fused.
210 static LinalgOp fuse(OpBuilder &b, LinalgOp producer,
211                      const DenseMap<unsigned, Range> &fusedLoopsAndRanges) {
212 
213   unsigned nPar = producer.getNumParallelLoops();
214   unsigned nRed = producer.getNumReductionLoops();
215   unsigned nWin = producer.getNumWindowLoops();
216   SmallVector<Range, 8> loopRanges(nPar + nRed + nWin);
217   for (auto fusedLoops : fusedLoopsAndRanges)
218     loopRanges[fusedLoops.first] = fusedLoops.second;
219 
220   // Iterate over all dimensions. For the dimensions not covered by
221   // `fusedLoopsAndRanges`, we need to explicitly compute the shape that
222   // defines the loop ranges using the `producer`.
223   for (unsigned i = 0, nLoops = loopRanges.size(); i < nLoops; ++i) {
224     if (loopRanges[i].offset)
225       LLVM_DEBUG(llvm::dbgs()
226                  << "existing LoopRange: " << loopRanges[i] << "\n");
227     else {
228       auto shapeDim = getShapeDefiningLoopRange(producer, i);
229       loopRanges[i] = Range{std_constant_index(0),
230                             std_dim(shapeDim.shape, shapeDim.dimension),
231                             std_constant_index(1)};
232       LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n");
233     }
234   }
235 
236   return cloneWithLoopRanges(b, producer.getLoc(), producer, loopRanges);
237 }
238 
239 /// Get the loop range for a dimension `dim` based on the `shapedOperand`. It is
240 /// expected to be defined by a subview op or a subtensor op.
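///
/// Illustrative example (hypothetical values): for
/// %sv = subview %A[%i, %j] [%m, %n] [1, 1], `dim` = 1 yields the range
/// {offset = %j, size = %n, stride = 1}.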
241 static Range getRangeFromOperandShape(OpBuilder &b, Location loc,
242                                       Value shapedOperand, unsigned dim) {
243   Operation *shapeProducingOp = shapedOperand.getDefiningOp();
244   if (auto subViewOp = dyn_cast<SubViewOp>(shapeProducingOp))
245     return subViewOp.getOrCreateRanges(b, loc)[dim];
246   if (auto subTensorOp = dyn_cast<SubTensorOp>(shapeProducingOp))
247     return subTensorOp.getOrCreateRanges(b, loc)[dim];
248   llvm_unreachable("SubviewOp or SubTensorOp expected");
249 }
250 
251 /// Fuses the producer of `producerIdx` into the loop immediately enclosing
252 /// `consumer`. This is achieved by "recomputing" the `producer` at the time it
253 /// is needed just before the `consumer`.
254 ///
255 /// Depending on the type of `consumer.getShapedOperand(consumerIdx)`, there are
256 /// 2 cases:
257 ///   1. Buffer case: `producerIdx` is the index of the buffer in
258 ///      `producer.getOutputBuffers()`.
259 ///   2. Tensor case: `producerIdx` is the index of the tensor in
260 ///      `producer.getResults()`.
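///
/// Illustrative example (hypothetical ops): fusing linalg.fill(%A, %cst) into
/// linalg.matmul ins(%svA, %B) outs(%C), where %svA is a subview of %A, uses
/// producerIdx = 0 (the fill's single output buffer) and consumerIdx = 0 (the
/// matmul operand that reads through %svA).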
261 static LinalgOp fuse(OpBuilder &b, LinalgOp producer, unsigned producerIdx,
262                      LinalgOp consumer, unsigned consumerIdx) {
263   AffineMap producerMap = producer.getOutputIndexingMap(producerIdx);
264   LLVM_DEBUG(llvm::dbgs() << "Producer Idx: " << producerIdx
265                           << ", producer map: " << producerMap << "\n");
266   DenseMap<unsigned, Range> fusedLoopsAndRanges;
267   Location loc = consumer.getLoc();
268   Value shapedOperand = consumer.getShapedOperand(consumerIdx);
269   for (auto en : llvm::enumerate(producerMap.getResults())) {
270     unsigned posInProducerLoop = en.value().cast<AffineDimExpr>().getPosition();
271     fusedLoopsAndRanges[posInProducerLoop] =
272         getRangeFromOperandShape(b, loc, shapedOperand, en.index());
273   }
274   return fuse(b, producer, fusedLoopsAndRanges);
275 }
276 
277 // Encode structural fusion safety preconditions.
278 // Some of these will be lifted in the future with better analysis.
279 static bool isStructurallyFusableProducer(LinalgOp producer, Value consumedView,
280                                           LinalgOp consumer) {
281   assert(producer.hasBufferSemantics() &&
282          "expected linalg op with buffer semantics");
283   assert(consumer.hasBufferSemantics() &&
284          "expected linalg op with buffer semantics");
285   if (producer.getNumOutputs() != 1) {
286     LLVM_DEBUG(llvm::dbgs() << "\nNot structurally fusable (multi-output)");
287     return false;
288   }
289   // Only fuse when the producer block dominates.
290   DominanceInfo dom(producer.getOperation());
291   if (!dom.dominates(producer->getBlock(), consumer->getBlock())) {
292     LLVM_DEBUG(
293         llvm::dbgs()
294         << "\nNot structurally fusable (producer block does not dominate)");
295     return false;
296   }
297   return true;
298 }
299 
300 bool mlir::linalg::isProducerLastWriteOfView(const LinalgDependenceGraph &graph,
301                                              LinalgOp consumer,
302                                              Value consumedView,
303                                              LinalgOp producer) {
304   assert(producer.hasBufferSemantics() &&
305          "expected linalg op with buffer semantics");
306   assert(consumer.hasBufferSemantics() &&
307          "expected linalg op with buffer semantics");
308   // Make some simple structural checks that alleviate the need for more
309   // complex analyses.
310   if (!isStructurallyFusableProducer(producer, consumedView, consumer)) {
311     LLVM_DEBUG(llvm::dbgs() << "\n***Not static last write due to structure:\t"
312                             << *producer.getOperation());
313     return false;
314   }
315   // Check for any interleaved write to consumedView.
316   if (!graph.findCoveringWrites(producer, consumer, consumedView).empty()) {
317     LLVM_DEBUG(llvm::dbgs() << "\n***Not fusable due to interleaved write:\t"
318                             << *producer.getOperation());
319     return false;
320   }
321   return true;
322 }
323 
324 bool mlir::linalg::isFusableInto(const LinalgDependenceGraph &graph,
325                                  LinalgOp consumer, Value consumedView,
326                                  LinalgOp producer) {
327   assert(producer.hasBufferSemantics() &&
328          "expected linalg op with buffer semantics");
329   assert(consumer.hasBufferSemantics() &&
330          "expected linalg op with buffer semantics");
331   if (!isProducerLastWriteOfView(graph, consumer, consumedView, producer))
332     return false;
333   // Check for any fusion-preventing dependence on any shape read/written
334   // that fusion would violate.
335   if (!graph.findCoveringDependences(producer, consumer).empty()) {
336     LLVM_DEBUG(llvm::dbgs()
337                << "\n***Not fusable due to an interleaved dependence:\t"
338                << *producer.getOperation());
339     return false;
340   }
341   if (auto convOp = dyn_cast<linalg::ConvOp>(producer.getOperation())) {
342     // TODO: add a level of indirection to linalg.generic.
343     if (convOp.padding())
344       return false;
345   }
346   if (auto convOp = dyn_cast<linalg::ConvOp>(consumer.getOperation())) {
347     // TODO: add a level of indirection to linalg.generic.
348     if (convOp.padding())
349       return false;
350   }
351   return true;
352 }
353 
354 static bool isSameSubView(Value a, Value b) {
355   if (a == b)
356     return true;
357   auto sva = a.getDefiningOp<SubViewOp>();
358   auto svb = b.getDefiningOp<SubViewOp>();
359   if (!sva || !svb)
360     return false;
361   if (!isSameSubView(sva.getViewSource(), svb.getViewSource()))
362     return false;
363   if (sva.getType() != svb.getType())
364     return false;
365   if (sva.getNumOperands() != svb.getNumOperands())
366     return false;
367   if (sva.static_offsets() != svb.static_offsets())
368     return false;
369   if (sva.static_sizes() != svb.static_sizes())
370     return false;
371   if (sva.static_strides() != svb.static_strides())
372     return false;
373   /// Skip the "source" operand.
374   for (unsigned idx = 1, e = sva.getNumOperands(); idx != e; ++idx)
375     if (sva.getOperand(idx) != svb.getOperand(idx))
376       return false;
377   return true;
378 }
379 
380 static Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
381 findFusableProducer(LinalgOp consumer, unsigned consumerIdx,
382                     const LinalgDependenceGraph &dependenceGraph) {
383   // Only consider RAW and WAW dependences for now.
384   for (auto depType : {
385            LinalgDependenceGraph::DependenceType::RAW,
386            LinalgDependenceGraph::DependenceType::WAW,
387        }) {
388     for (auto dependence : llvm::make_filter_range(
389              dependenceGraph.getDependencesInto(consumer, depType),
390              [consumerIdx](
391                  LinalgDependenceGraph::LinalgDependenceGraphElem elem) {
392                return elem.indexingOpView.operandIndex == consumerIdx;
393              })) {
394       auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
395 
396       // Check that the dependence is indeed on the input `consumerIdx` view.
397       auto consumedView =
398           consumer.getBuffer(dependence.indexingOpView.operandIndex);
399       if (!isSameSubView(consumer.getBuffer(consumerIdx), consumedView))
400         continue;
401 
402       // The consumer consumes this view; `isStructurallyFusableProducer`
403       // also checks whether it is a strict subview of the producer view.
404       auto producedView =
405           producer.getBuffer(dependence.dependentOpView.operandIndex);
406       LLVM_DEBUG(llvm::dbgs()
407                  << "\n"
408                  << LinalgDependenceGraph::getDependenceTypeStr(depType)
409                  << "producer: " << *producer.getOperation()
410                  << " view: " << producedView << " output index: "
411                  << dependence.dependentOpView.operandIndex -
412                         producer.getNumInputs()
413                  << "\n");
414       (void)producedView;
415 
416       // Simple fusability checks.
417       if (!isFusableInto(dependenceGraph, consumer, consumedView, producer))
418         continue;
419 
420       return dependence;
421     }
422   }
423   return {};
424 }
425 
426 Optional<FusionInfo>
427 mlir::linalg::fuseProducerOfBuffer(OpBuilder &b, LinalgOp consumer,
428                                    unsigned consumerIdx,
429                                    const LinalgDependenceGraph &graph) {
430   Optional<LinalgDependenceGraph::LinalgDependenceGraphElem> fusableDependence =
431       findFusableProducer(consumer, consumerIdx, graph);
432   if (!fusableDependence)
433     return {};
434 
435   LinalgOp producerOp = cast<LinalgOp>(fusableDependence->dependentOpView.op);
436   // No fusion needed if the producer is already in the same block as consumer.
437   if (consumer->getBlock() == producerOp->getBlock())
438     return {};
439 
440   unsigned producerIdx = fusableDependence->dependentOpView.operandIndex -
441                          producerOp.getNumInputs();
442   Value consumerView = consumer.getShapedOperand(consumerIdx);
443 
444   // Must be a subview or a slice to guarantee there are loops we can fuse
445   // into.
446   auto subView = consumerView.getDefiningOp<SubViewOp>();
447   auto slice = consumerView.getDefiningOp<SliceOp>();
448   if (!subView && !slice) {
449     LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subview or slice)");
450     return {};
451   }
452 
453   // Fuse `producer` just before `consumer`.
454   OpBuilder::InsertionGuard g(b);
455   b.setInsertionPoint(consumer.getOperation());
456   ScopedContext scope(b, consumer.getLoc());
457   LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
458 
459   auto fusedProducer = fuse(b, producerOp, producerIdx, consumer, consumerIdx);
460   return FusionInfo{producerOp, fusedProducer};
461 }
462 
463 /// Walk back use-def chain through scf::For yields.
464 /// Sets `producer` and `outputIndex` if it finds a producer LinalgOp.
465 static void getProducerOfTensor(Value tensor, LinalgOp &producer,
466                                 unsigned &outputIndex) {
467   if (!tensor.getType().isa<RankedTensorType>())
468     return;
469 
470   while (true) {
471     if (auto linalgOp = tensor.getDefiningOp<LinalgOp>()) {
472       producer = linalgOp;
473       outputIndex = tensor.cast<OpResult>().getResultNumber();
474       return;
475     }
476     if (auto subTensorOp = tensor.getDefiningOp<SubTensorOp>()) {
477       tensor = subTensorOp.source();
478       continue;
479     }
480     if (auto blockArg = tensor.dyn_cast<BlockArgument>()) {
481       // BlockArguments have no defining op; for an scf.for iter_arg, walk
482       // back to the matching init operand (arg 0 is the induction variable).
483       if (auto forOp = dyn_cast_or_null<scf::ForOp>(blockArg.getOwner()->getParentOp()))
484         if (blockArg.getArgNumber() > 0) {
485           tensor = forOp.getIterOperands()[blockArg.getArgNumber() - 1];
          continue;
        }
    }
486     return;
487   }
488 }
489 
490 Optional<FusionInfo> mlir::linalg::fuseProducerOfTensor(OpBuilder &b,
491                                                         LinalgOp consumer,
492                                                         unsigned consumerIdx) {
493   Value inputTensor = consumer.getInput(consumerIdx);
494   LinalgOp producerOp;
495   unsigned producerIdx;
496   getProducerOfTensor(inputTensor, producerOp, producerIdx);
497 
498   // Must be a subtensor to guarantee there are loops we can fuse into.
499   auto subTensor = inputTensor.getDefiningOp<SubTensorOp>();
500   if (!subTensor || !producerOp) {
501     LLVM_DEBUG(llvm::dbgs() << "\nNot fusable (not a subtensor)");
502     return {};
503   }
504 
505   // No fusion needed if the producer is already in the same block as consumer.
506   if (consumer->getBlock() == producerOp->getBlock())
507     return {};
508 
509   // Insert fused `producer` just before `consumer`.
510   OpBuilder::InsertionGuard g(b);
511   b.setInsertionPoint(consumer.getOperation());
512   ScopedContext scope(b, consumer.getLoc());
513   LLVM_DEBUG(llvm::dbgs() << "Fuse into consumer: " << *consumer << "\n");
514   LinalgOp fusedProducer =
515       fuse(b, producerOp, producerIdx, consumer, consumerIdx);
516 
517   // Replace use.
518   // Canonicalizations are not guaranteed to have happened before constructing
519   // `fusedProducer`. In the tensor case this can result in temporary type
520   // mismatches. Insert a `tensor_cast` op to propagate the transformation
521   // invariant that types are compatible.
522   Value def = fusedProducer->getResult(producerIdx);
523   OpOperand &use = consumer->getOpOperand(consumerIdx);
524   Type consumerType = use.get().getType();
525   if (consumerType != def.getType())
526     def = b.create<TensorCastOp>(fusedProducer.getLoc(), consumerType, def);
527   use.set(def);
528   return FusionInfo{producerOp, fusedProducer};
529 }
530 
531 /// Prune all dimensions that are not of parallel iterator type from `map`.
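///
/// Illustrative example (hypothetical op): for iterator types
/// ["parallel", "parallel", "reduction"] and map affine_map<(i, j, k) -> (i, j)>,
/// the reduction dimension k is projected out, yielding
/// affine_map<(i, j) -> (i, j)>.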
532 static AffineMap pruneReductionDimsFromMap(ArrayRef<Attribute> iteratorTypes,
533                                            AffineMap map) {
534   SmallVector<unsigned, 2> projectedDims;
535   for (auto attr : llvm::enumerate(iteratorTypes)) {
536     if (!isParallelIterator(attr.value()))
537       projectedDims.push_back(attr.index());
538   }
539   return getProjectedMap(map, projectedDims);
540 }
541 
542 /// Returns the mapping from iterations in the consumer to the iterations in
543 /// the producer that write to the same location. To do so, use:
544 /// - indexing map of the fused view in the consumer : consumerIndexMap
545 /// - indexing map of the fused view in the producer : producerIndexMap
546 ///     consumerLoopToProducerLoop =
547 ///       inverse(producerIndexMap).compose(consumerIndexMap)
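///
/// Illustrative example (mirroring Example 1 further below): with
/// producerIndexMap = affine_map<(i, j) -> (i, j)> and
/// consumerIndexMap = affine_map<(i, j, k) -> (i, j)>, the pruned producer map
/// is already a permutation, its inverse is (i, j) -> (i, j), and the
/// composition yields
/// consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)>.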
548 static Optional<AffineMap> getConsumerLoopToProducerLoopMap(
549     LinalgDependenceGraph::LinalgDependenceGraphElem dependence) {
550   auto producer = cast<LinalgOp>(dependence.dependentOpView.op);
551   AffineMap producerIndexingMap =
552       producer.getIndexingMap(dependence.dependentOpView.operandIndex);
553   auto consumer = cast<LinalgOp>(dependence.indexingOpView.op);
554   AffineMap consumerIndexingMap =
555       consumer.getIndexingMap(dependence.indexingOpView.operandIndex);
556 
557   AffineMap prunedProducerIndexingMap = pruneReductionDimsFromMap(
558       producer.iterator_types().getValue(), producerIndexingMap);
559   if (!prunedProducerIndexingMap.isPermutation())
560     return None;
561 
562   if (consumerIndexingMap.getNumResults() !=
563       prunedProducerIndexingMap.getNumResults())
564     return None;
565 
566   LLVM_DEBUG({
567     llvm::dbgs() << "\t producerMap : ";
568     producerIndexingMap.print(llvm::dbgs());
569     llvm::dbgs() << "  pruned : ";
570     prunedProducerIndexingMap.print(llvm::dbgs());
571     llvm::dbgs() << "\n";
572     llvm::dbgs() << "\t consumerMap : ";
573     consumerIndexingMap.print(llvm::dbgs());
574     llvm::dbgs() << "\n";
575   });
576 
577   AffineMap invProducerIndexMap = inversePermutation(prunedProducerIndexingMap);
578   if (!invProducerIndexMap)
579     return None;
580 
581   return invProducerIndexMap.compose(consumerIndexingMap);
582 }
583 
584 /// Given a projected permutation `map`, returns true if the map changes the
585 /// order in which the fused loop dimensions appear.
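///
/// Illustrative example (hypothetical maps): with fusableLoops = {0, 1},
/// affine_map<(i, j) -> (i, j)> keeps the fused dimensions in order (returns
/// false), while affine_map<(i, j) -> (j, i)> swaps them (returns true).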
586 static bool doesTransposeAccess(AffineMap map,
587                                 const std::set<unsigned> &fusableLoops) {
588   Optional<unsigned> lastFusableLoop;
589   for (unsigned pos : llvm::map_range(map.getResults(), [](AffineExpr expr) {
590          return expr.cast<AffineDimExpr>().getPosition();
591        })) {
592     if (!fusableLoops.count(pos))
593       continue;
594     if (!lastFusableLoop) {
595       lastFusableLoop = pos;
596       continue;
597     }
598     if (pos <= lastFusableLoop.getValue())
599       return true;
600     lastFusableLoop = pos;
601   }
602   return false;
603 }
604 
605 /// Returns the positions of the loops in `op` that can be tiled based on the
606 /// operations that are to be fused with it. For example, in a
607 ///
608 ///   linalg.matmul ins(%a, %b : ...) outs(%c : ...)
609 ///
610 /// if the producer of %a needs to be fused with this op, only the `i` loop of
611 /// the matmul can be tiled while fusing. If the producers of %a and %b are to
612 /// be fused, then no loops can be tiled while fusing. The conditions used are:
613 /// 1. Only parallel loops can be used for tile + fuse. Find the number of
614 ///    common outer parallel loops between the op and its producers being fused.
615 /// 2. Of the parallel loops, only some can be fused: only those loops can be
616 ///    fused whose iteration space touches no more than a single tile of the
617 ///    fused operation. This is because the producer (which is writing
618 ///    the fused subview) has update semantics.
619 ///
620 /// Since an inverse computation is needed, we need to consider the projection
621 /// of the producerIndexMap w.r.t. the parallel loops. The actual fusable loops
622 /// are the dimensions of the consumerLoopToProducerLoop map that correspond to
623 /// parallel loops and appear in the results of the map.
624 ///
625 /// Example 1:
626 ///   linalg.fill(%c, %cst)
627 ///   linalg.matmul ins(%a, %b) outs(%c)
628 ///     Number of parallel loops : 2
629 ///     producerIndexMap = affine_map<(i, j) ->(i , j)>
630 ///     consumerIndexMap = affine_map<(i, j, k) -> (i, j)>
631 ///     consumerLoopToProducerLoop = affine_map<(i, j, k) -> (i, j)>
632 ///     Fused dimensions : i, j
633 ///
634 /// Example 2:
635 ///   linalg.matmul ins(%a, %b) outs(%c)
636 ///   linalg.generic {indexing_maps = [affine_map<(i, j) -> (j, i)>, ...
637 ///                   iterator_types = ["parallel", "parallel"]}
638 ///     ins(%c) ...
639 ///
640 ///     Number of parallel loops = 2:
641 ///     producerIndexMap (projected to parallel loops) =
642 ///       affine_map<(i, j) -> (i, j)>
643 ///     consumerLoopToProducerLoop2 = affine_map<(i, j) -> (j, i)>
644 ///     Fused dimensions : i, j
645 ///
646 /// Example 3:
647 ///   linalg.copy(%s, %b)
648 ///   linalg.matmul ins(%a, %b) outs(%c)
649 ///
650 ///   Number of parallel loops = 2
651 ///   producerIndexMap : affine_map<(i, j) -> (i, j)>
652 ///   consumerLoopToProducerLoop = affine_map<(i, j, k) -> (k, j)>
653 ///     submap with only parallel loops = affine_map<(i, j) -> (j)>
654 ///   Fused dimensions : j
655 static std::set<unsigned>
656 collectFusableLoops(ArrayRef<LinalgOp> ops,
657                     const FusableOpDependencesTy &fusableDependences) {
658   assert(!ops.empty());
659   auto getNumOuterParallelLoops = [](LinalgOp linalgOp) {
660     return linalgOp.iterator_types()
661         .getValue()
662         .take_while([](Attribute attr) -> bool {
663           return attr.cast<StringAttr>().getValue() ==
664                  getParallelIteratorTypeName();
665         })
666         .size();
667   };
668 
669   size_t numOuterParallelLoops = getNumOuterParallelLoops(ops.back());
670   for (auto op : ops.drop_back()) {
671     numOuterParallelLoops =
672         std::min(numOuterParallelLoops, getNumOuterParallelLoops(op));
673   }
674 
675   std::set<unsigned> fusableLoops;
676   auto range = llvm::seq<unsigned>(0, numOuterParallelLoops);
677   fusableLoops.insert(range.begin(), range.end());
678 
679   for (auto op : reverse(ops)) {
680     for (auto dependence : fusableDependences.lookup(op)) {
681       LLVM_DEBUG({
682         llvm::dbgs() << "\t fusable :";
683         for (unsigned i : fusableLoops)
684           llvm::dbgs() << " " << i;
685         llvm::dbgs() << "\n";
686       });
687 
688       Optional<AffineMap> consumerLoopToProducerLoop =
689           getConsumerLoopToProducerLoopMap(dependence);
690       if (!consumerLoopToProducerLoop) {
691         op.emitRemark("failed to get map from consumer loop to producer loop");
692         return {};
693       }
694       // TODO: This condition is only an implementation limitation. When
695       // fusing the operation, if the accesses in the producer/consumer are
696       // transposes of each other, the loop bounds for the tiled producer can
697       // be manipulated accordingly. This requires some additional bookkeeping
698       // in the implementation of tile+fuse that is deferred to later.
699       if (doesTransposeAccess(*consumerLoopToProducerLoop, fusableLoops)) {
700         op.emitRemark("unhandled fusion when fusion requires permutation");
701         return {};
702       }
703 
704       std::set<unsigned> candidates;
705       for (AffineExpr expr : consumerLoopToProducerLoop->getResults()) {
706         unsigned position = expr.cast<AffineDimExpr>().getPosition();
707         if (fusableLoops.count(position))
708           candidates.insert(position);
709       }
710       LLVM_DEBUG({
711         llvm::dbgs() << "\t candidates :";
712         for (unsigned i : candidates)
713           llvm::dbgs() << " " << i;
714         llvm::dbgs() << "\n";
715       });
716       if (candidates.empty())
717         return {};
718       std::swap(candidates, fusableLoops);
719     }
720   }
721 
722   return fusableLoops;
723 }
724 
725 /// Find all dependences that are fusable.
726 FusableOpDependencesTy mlir::linalg::findAllFusableDependences(
727     ArrayRef<LinalgOp> ops, const LinalgDependenceGraph &dependenceGraph) {
728   FusableOpDependencesTy fusableDependences;
729   // TODO: Currently fusion would not be legal if the fusable dependence is to
730   // the same producer but with a different indexing map in the consumer. Fix
731   // this, but in the meantime disallow such a fusion.
732   DenseMap<Operation *, AffineMap> fusedProducerIndexingMap;
733   for (LinalgOp op : reverse(ops)) {
734     for (auto operandIndex :
735          llvm::seq<unsigned>(0, op.getNumInputsAndOutputBuffers())) {
736       Optional<LinalgDependenceGraph::LinalgDependenceGraphElem>
737           fusableDependence =
738               findFusableProducer(op, operandIndex, dependenceGraph);
739       if (!fusableDependence)
740         continue;
741       LinalgOp producerOp =
742           cast<LinalgOp>(fusableDependence->dependentOpView.op);
743       // Do not fuse dependences that are to operations not in the same basic
744       // block. This avoids moving fused operations across loops that might
745       // themselves carry dependences, making the fusion illegal.
746       if (producerOp->getBlock() != op->getBlock()) {
747         op.emitRemark("unhandled fusion of ops in different basic blocks");
748         return FusableOpDependencesTy{};
749       }
750       // Make sure that the indexing map of the view used for fusion in the
751       // producer is a projected permutation.
752       unsigned producerIdx = fusableDependence->dependentOpView.operandIndex;
753       AffineMap producerMap = producerOp.getIndexingMap(producerIdx);
754       if (!producerMap.isProjectedPermutation()) {
755         op.emitRemark(
756             "unhandled non permutation indexing map for fused view in "
757             "producer for operand at index ")
758             << operandIndex;
759         return FusableOpDependencesTy{};
760       }
761 
762       unsigned consumerIdx = fusableDependence->indexingOpView.operandIndex;
763       AffineMap consumerMap = op.getIndexingMap(consumerIdx);
764       if (!consumerMap.isProjectedPermutation()) {
765         op.emitRemark(
766             "unhandled case where indexing map for fused view in the consumer "
767             "is "
768             "not a projected permutation while fusing at index ")
769             << operandIndex;
770         return FusableOpDependencesTy{};
771       }
772 
773       // Check if the producer is already a fusion candidate. Cannot fuse this
774       // dependence if it has a different indexing map when used in the
775       // consumer.
776       if (fusedProducerIndexingMap.count(producerOp.getOperation()) &&
777           fusedProducerIndexingMap[producerOp.getOperation()] != consumerMap) {
778         op.emitRemark(
779             "unhandled fusion to the same producer but with different "
780             "indexing maps");
781         return FusableOpDependencesTy{};
782       }
783       fusedProducerIndexingMap[producerOp.getOperation()] = consumerMap;
784 
785       fusableDependences[producerOp.getOperation()].push_back(
786           *fusableDependence);
787     }
788   }
789   return fusableDependences;
790 }
791 
792 /// Tile the fused loops in the root operation, by setting the tile sizes for
793 /// all other loops to zero (those will be tiled later).
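///
/// Illustrative example (hypothetical sizes): with tile sizes [32, 32, 64] and
/// fusedLoops = {0, 1}, the root op is tiled with effective sizes [32, 32, 0],
/// i.e. the `k` loop is left untiled for a later tiling step.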
794 static Optional<TiledLinalgOp> tileRootOperation(
795     OpBuilder &builder, LinalgOp op, ArrayRef<Value> tileSizeVector,
796     const LinalgTilingOptions &options, const std::set<unsigned> &fusedLoops) {
797   SmallVector<Value, 4> tileSizes(tileSizeVector.begin(), tileSizeVector.end());
798   auto zero = std_constant_index(0);
799   for (unsigned i = 0, e = tileSizes.size(); i != e; ++i)
800     if (!fusedLoops.count(i))
801       tileSizes[i] = zero;
802   LinalgTilingOptions tileFusedLoopsOptions = options;
803   tileFusedLoopsOptions.setTileSizes(tileSizes);
804   return tileLinalgOp(builder, op, tileFusedLoopsOptions);
805 }
806 
807 /// Fuse the operations in `fusionCandidates` with `tiledOp`. The latter is
808 /// expected to be a tiled operation such that it is valid to fuse all
809 /// operations in `fusionCandidates`, i.e. to move them within the inter-tile
810 /// loops of `tiledOp`.
811 static SmallVector<LinalgOp, 1>
812 fuseOperations(OpBuilder &builder, LinalgOp tiledOp,
813                ArrayRef<LinalgOp> fusionCandidates,
814                const FusableOpDependencesTy &fusableDependences,
815                const std::set<unsigned> &fusedLoops) {
816   OpBuilder::InsertionGuard guard(builder);
817   builder.setInsertionPoint(tiledOp);
818   DenseMap<unsigned, Range> fusedLoopsAndRanges;
819   for (unsigned loop : fusedLoops) {
820     ShapeDimension shapeDim = getShapeDefiningLoopRange(tiledOp, loop, true);
821     fusedLoopsAndRanges[loop] = getRangeFromOperandShape(
822         builder, tiledOp.getLoc(), shapeDim.shape, shapeDim.dimension);
823   }
824 
825   SmallVector<LinalgOp, 1> fusedOps(fusionCandidates.size());
826   for (auto candidate : enumerate(llvm::reverse(fusionCandidates))) {
827     LinalgOp fusedOp = fuse(builder, candidate.value(), fusedLoopsAndRanges);
828     fusedOps[fusionCandidates.size() - candidate.index() - 1] = fusedOp;
829     builder.setInsertionPoint(fusedOp);
830   }
831   return fusedOps;
832 }
833 
834 template <typename LoopType>
835 static Optional<TiledAndFusedLinalgOps>
836 tileAndFuseLinalgOpsImpl(OpBuilder &builder, ArrayRef<LinalgOp> ops,
837                          const LinalgDependenceGraph &dependenceGraph,
838                          const LinalgTilingOptions &tilingOptions) {
839   if (ops.empty())
840     return llvm::None;
841   LinalgOp rootOp = ops.back();
842   for (auto op : enumerate(ops)) {
843     // TODO: Nothing in the fusion of a sequence of ops is specific to
844     // buffers. This check can be removed after it is tested on tensors.
845     LinalgOp linalgOp = op.value();
846     if (!linalgOp.hasBufferSemantics()) {
847       linalgOp.emitError("tile and fuse only tested for buffer operation");
848       return llvm::None;
849     }
850   }
851   // TODO: Support interchange with tile + fuse. This might actually help do
852   // better fusion.
853   if (!tilingOptions.interchangeVector.empty()) {
854     rootOp.emitError("unable to handle tile and fuse with interchange");
855     return llvm::None;
856   }
857 
858   OpBuilder::InsertionGuard guard(builder);
859   builder.setInsertionPoint(rootOp);
860   ScopedContext scope(builder, rootOp.getLoc());
861 
862   // Find all the producers.
863   FusableOpDependencesTy fusableDependences =
864       findAllFusableDependences(ops, dependenceGraph);
865   if (fusableDependences.empty())
866     return llvm::None;
867 
868   TiledAndFusedLinalgOps ret;
869   // Find the loops that can be tiled and fused.
870   ret.fusedLoopDims = collectFusableLoops(ops, fusableDependences);
871 
872   // If there are no fusable dependences or there are no tile+fusable loops,
873   // just return.
874   if (ret.fusedLoopDims.empty()) {
875     return llvm::None;
876   }
877 
878   // Tile the fused loops in the last operation in the list.
879   SmallVector<Value, 4> tileSizeVector =
880       tilingOptions.tileSizeComputationFunction(builder, rootOp);
881   Optional<TiledLinalgOp> tiledRootOp = tileRootOperation(
882       builder, rootOp, tileSizeVector, tilingOptions, ret.fusedLoopDims);
883   if (!tiledRootOp) {
884     rootOp.emitError("failed to tile the fused loops");
885     return llvm::None;
886   }
887   ret.op = tiledRootOp->op;
888   ret.fusedLoops.assign(tiledRootOp->loops.begin(), tiledRootOp->loops.end());
889 
890   // Fuse the other operations into the fused inter-tile loops produced above.
891   ret.fusedProducers = fuseOperations(builder, ret.op, ops.drop_back(),
892                                       fusableDependences, ret.fusedLoopDims);
893   return ret;
894 }
895 
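/// A minimal usage sketch (assumptions: the caller has collected `ops` with
/// the root op last, built a LinalgDependenceGraph `graph`, and holds an
/// OpBuilder `b`; the tile sizes below are hypothetical):
///
///   LinalgTilingOptions options =
///       LinalgTilingOptions()
///           .setTileSizes({32, 32, 64})
///           .setLoopType(LinalgTilingLoopType::ParallelLoops);
///   if (auto res = tileAndFuseLinalgOps(b, ops, graph, options)) {
///     // res->op is the tiled root, res->fusedProducers are the clones fused
///     // into the inter-tile loops; dead original producers can now be erased.
///   }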
896 Optional<TiledAndFusedLinalgOps>
897 mlir::linalg::tileAndFuseLinalgOps(OpBuilder &builder, ArrayRef<LinalgOp> ops,
898                                    const LinalgDependenceGraph &dependenceGraph,
899                                    const LinalgTilingOptions &tilingOptions) {
900   switch (tilingOptions.loopType) {
901   case LinalgTilingLoopType::Loops:
902     return tileAndFuseLinalgOpsImpl<scf::ForOp>(builder, ops, dependenceGraph,
903                                                 tilingOptions);
904   case LinalgTilingLoopType::ParallelLoops:
905     return tileAndFuseLinalgOpsImpl<scf::ParallelOp>(
906         builder, ops, dependenceGraph, tilingOptions);
907   default:;
908   }
909   return llvm::None;
910 }
911