• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- Fusion.cpp - Implementation of linalg Fusion -----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the linalg dialect Fusion on tensors operations pass.
10 //
11 //===----------------------------------------------------------------------===//
12 #include "PassDetail.h"
13 #include "mlir/Dialect/Affine/IR/AffineOps.h"
14 #include "mlir/Dialect/Linalg/IR/LinalgOps.h"
15 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
16 #include "mlir/Dialect/Linalg/Passes.h"
17 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
18 #include "mlir/Dialect/Linalg/Utils/Utils.h"
19 #include "mlir/IR/AffineExpr.h"
20 #include "mlir/IR/AffineMap.h"
21 #include "mlir/IR/PatternMatch.h"
22 #include "mlir/Support/LLVM.h"
23 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
24 
25 using namespace mlir;
26 using namespace mlir::linalg;
27 
28 /// Implementation of fusion of generic ops and indexed_generic ops.
29 // struct FuseGenericOpsOnTensors {
areTensorOpsFusable(LinalgOp producer,LinalgOp consumer,unsigned consumerIdx)30 static bool areTensorOpsFusable(LinalgOp producer, LinalgOp consumer,
31                                 unsigned consumerIdx) {
32   // Producer and consumer must have tensor semantics.
33   if (!producer.hasTensorSemantics() || !consumer.hasTensorSemantics())
34     return false;
35 
36   // Verify that
37   // - the producer has all "parallel" iterator type.
38   if (producer.getNumParallelLoops() != producer.getNumLoops())
39     return false;
40 
41   // Get the consumer index map. The number of results of the consumer index
42   // map must match the number of loops of the producer.
43   AffineMap consumerIndexMap = consumer.getIndexingMap(consumerIdx);
44   if (consumerIndexMap.getNumResults() != producer.getNumLoops())
45     return false;
46 
47   // Finally the index_map for the result must be invertible. For now just
48   // verify it is a permutation.
49   AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
50   return producerResultIndexMap.isPermutation();
51 }
52 
53 /// Append to `fusedOpIndexingMapAttrs` the indexing maps for the operands of
54 /// the `producer` to use in the fused operation given the indexing map of the
55 /// result of the producer in the consumer.
getIndexingMapOfProducerOperandsInFusedOp(LinalgOp producer,AffineMap fusedConsumerArgIndexMap,SmallVectorImpl<Attribute> & fusedOpIndexingMapAttrs)56 static void getIndexingMapOfProducerOperandsInFusedOp(
57     LinalgOp producer, AffineMap fusedConsumerArgIndexMap,
58     SmallVectorImpl<Attribute> &fusedOpIndexingMapAttrs) {
59   // The indexing map in the consumer op (fusedConsumerArgIndexMap) is a map
60   // from consumer loop -> consumer arg tensor index/producer result tensor
61   // index. The fused loop is same as the consumer loop. For each producer arg
62   // the indexing map to be computed is a map from consumer loop -> producer
63   // arg tensor index.
64 
65   AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
66   // producerResultIndexMap is a map from producer loop -> tensor index.
67   // Compute the inverse to get map from tensor index -> producer loop.
68   // The inverse is a map from producer result tensor index -> producer loop.
69   AffineMap invProducerResultIndexMap =
70       inversePermutation(producerResultIndexMap);
71   assert(invProducerResultIndexMap &&
72          "expected producer result indexig map to be invertible");
73   for (unsigned argNum : llvm::seq<unsigned>(0, producer.getNumInputs())) {
74     // argMap is a map from producer loop -> producer arg tensor index.
75     AffineMap argMap = producer.getInputIndexingMap(argNum);
76 
77     // Compose argMap with invProducerResultIndexMap to get a map from
78     // producer result tensor index -> producer arg tensor index.
79     AffineMap t1 = argMap.compose(invProducerResultIndexMap);
80 
81     // Compose t1 with fusedConsumerArgIndexMap gives an indexing map from
82     // consumer loop/ fused loop -> producer arg tensor index.
83     AffineMap indexingMap = t1.compose(fusedConsumerArgIndexMap);
84     fusedOpIndexingMapAttrs.push_back(AffineMapAttr::get(indexingMap));
85   }
86 }
87 
88 /// Generate the region of the fused tensor operation. The region of the fused
89 /// op must be empty.
generateFusedTensorOpRegion(PatternRewriter & rewriter,Operation * fusedOp,LinalgOp producer,LinalgOp consumer,AffineMap consumerToProducerLoopsMap,unsigned consumerIdx,unsigned nloops)90 static void generateFusedTensorOpRegion(PatternRewriter &rewriter,
91                                         Operation *fusedOp, LinalgOp producer,
92                                         LinalgOp consumer,
93                                         AffineMap consumerToProducerLoopsMap,
94                                         unsigned consumerIdx, unsigned nloops) {
95   // Build the region of the fused op.
96   Block &producerBlock = producer->getRegion(0).front();
97   Block &consumerBlock = consumer->getRegion(0).front();
98   Block *fusedBlock = new Block();
99   fusedOp->getRegion(0).push_back(fusedBlock);
100   BlockAndValueMapping mapper;
101   OpBuilder::InsertionGuard guard(rewriter);
102   rewriter.setInsertionPointToStart(fusedBlock);
103 
104   // The block arguments are
105   // [index_0, index_1, ... ,
106   //   consumer_operand_0, ... , consumer_operand_(`consumerIdx`-1),
107   //   producer_operand_0, ... , producer_operand_(n-1)],
108   //   consumer_operand_(`consumerIdx`), .. consumer_operand_(m-1)]
109   // , where n is the number of producer's operand and m is the number
110   // consumer's operand.
111   // If both `numProducerIndices` and `numConsumerIndices` are zero, this is a
112   // generic op. In this case, there are no indices in block arguments.
113   unsigned numProducerIndices = isa<IndexedGenericOp>(producer.getOperation())
114                                     ? producer.getNumLoops()
115                                     : 0;
116   unsigned numConsumerIndices = isa<IndexedGenericOp>(consumer.getOperation())
117                                     ? consumer.getNumLoops()
118                                     : 0;
119   unsigned numFusedOpIndices =
120       (isa<IndexedGenericOp>(producer.getOperation()) ||
121        isa<IndexedGenericOp>(consumer.getOperation()))
122           ? std::max(producer.getNumLoops(), consumer.getNumLoops())
123           : 0;
124   // Firstly, add all the indices to the block arguments.
125   for (unsigned i = 0, e = numFusedOpIndices; i < e; ++i)
126     fusedBlock->addArgument(rewriter.getIndexType());
127   // Map the arguments for the unmodified args from the consumer.
128   for (auto consumerArg : llvm::enumerate(consumerBlock.getArguments())) {
129     if (consumerArg.index() == consumerIdx + numConsumerIndices) {
130       // Map the arguments for the args from the producer.
131       for (auto producerArg : llvm::enumerate(producerBlock.getArguments())) {
132         // If producer is an indexed_generic op, map the indices from consumer
133         // loop to producer loop (because the fusedOp is built based on
134         // consumer's perspective).
135         if (producerArg.index() < numProducerIndices) {
136           auto newIndex = rewriter.create<mlir::AffineApplyOp>(
137               producer.getLoc(),
138               consumerToProducerLoopsMap.getSubMap(producerArg.index()),
139               fusedBlock->getArguments().take_front(numFusedOpIndices));
140           mapper.map(producerArg.value(), newIndex);
141         } else {
142           mapper.map(producerArg.value(),
143                      fusedBlock->addArgument(producerArg.value().getType()));
144         }
145       }
146       continue;
147     }
148 
149     // If consumer is an indexed_generic op, map the indices to the block
150     // arguments directly. Otherwise, add the same type of argument and map to
151     // it.
152     if (consumerArg.index() < numConsumerIndices) {
153       mapper.map(consumerArg.value(),
154                  fusedBlock->getArgument(consumerArg.index()));
155     } else {
156       mapper.map(consumerArg.value(),
157                  fusedBlock->addArgument(consumerArg.value().getType()));
158     }
159   }
160 
161   // Add operations from producer (except the yield operation) to the fused
162   // op.
163   for (auto &op : producerBlock.getOperations()) {
164     if (auto yieldOp = dyn_cast<linalg::YieldOp>(op)) {
165       // Lookup the value the yield operation is mapped to.
166       Value yieldVal = yieldOp.getOperand(0);
167       if (Value clonedVal = mapper.lookupOrNull(yieldVal))
168         mapper.map(consumerBlock.getArgument(consumerIdx + numConsumerIndices),
169                    clonedVal);
170       continue;
171     }
172     rewriter.clone(op, mapper);
173   }
174   for (auto &op : consumerBlock.getOperations())
175     rewriter.clone(op, mapper);
176 }
177 
178 static Optional<SmallVector<Value, 1>>
fuseTensorOpsImpl(LinalgOp producer,LinalgOp consumer,unsigned consumerIdx,PatternRewriter & rewriter)179 fuseTensorOpsImpl(LinalgOp producer, LinalgOp consumer, unsigned consumerIdx,
180                   PatternRewriter &rewriter) {
181   if (!areTensorOpsFusable(producer, consumer, consumerIdx))
182     return llvm::None;
183 
184   unsigned numFusedOperands =
185       producer.getNumInputs() + consumer.getNumInputs() - 1;
186 
187   // Compute the fused operands list,
188   SmallVector<Value, 2> fusedOperands;
189   fusedOperands.reserve(numFusedOperands);
190   auto consumerOperands = consumer.getInputs();
191   auto producerOperands = producer.getInputs();
192   fusedOperands.assign(consumerOperands.begin(),
193                        std::next(consumerOperands.begin(), consumerIdx));
194   fusedOperands.append(producerOperands.begin(), producerOperands.end());
195   fusedOperands.append(std::next(consumerOperands.begin(), consumerIdx + 1),
196                        consumerOperands.end());
197 
198   // Compute indexing_maps for the fused operation. The indexing_maps for the
199   // operands of the consumers that arent fused are the same. The
200   // indexing_maps for the producers need to be computed based on the
201   // indexing_map of the operand at consumerIdx in the consumer.
202   SmallVector<Attribute, 4> fusedIndexMaps;
203   auto consumerIndexMaps = consumer.indexing_maps();
204   fusedIndexMaps.reserve(fusedOperands.size() + consumer.getNumOutputs());
205   fusedIndexMaps.assign(consumerIndexMaps.begin(),
206                         std::next(consumerIndexMaps.begin(), consumerIdx));
207   // Compute indexing maps for the producer args in the fused operation.
208   getIndexingMapOfProducerOperandsInFusedOp(
209       producer, consumer.getInputIndexingMap(consumerIdx), fusedIndexMaps);
210 
211   // Append the indexing maps for the remaining consumer operands.
212   fusedIndexMaps.append(std::next(consumerIndexMaps.begin(), consumerIdx + 1),
213                         consumerIndexMaps.end());
214 
215   // Generate the fused op.
216   // Tensor-level fusion is only on ops without initTensors and outputBuffers.
217   LinalgOp fusedOp;
218   if (isa<GenericOp>(producer.getOperation()) &&
219       isa<GenericOp>(consumer.getOperation())) {
220     fusedOp =
221         rewriter
222             .create<GenericOp>(consumer.getLoc(), consumer->getResultTypes(),
223                                /*inputs=*/fusedOperands,
224                                /*outputBuffers=*/ValueRange{},
225                                /*initTensors=*/ValueRange{},
226                                rewriter.getArrayAttr(fusedIndexMaps),
227                                consumer.iterator_types(),
228                                /*doc=*/nullptr,
229                                /*library_call=*/nullptr,
230                                /*sparse=*/nullptr)
231             .getOperation();
232   } else {
233     fusedOp = rewriter
234                   .create<IndexedGenericOp>(
235                       consumer.getLoc(), consumer->getResultTypes(),
236                       /*inputs=*/fusedOperands,
237                       /*outputBuffers=*/ValueRange{},
238                       /*initTensors=*/ValueRange{},
239                       rewriter.getArrayAttr(fusedIndexMaps),
240                       consumer.iterator_types(),
241                       /*doc=*/nullptr,
242                       /*library_call=*/nullptr,
243                       /*sparse=*/nullptr)
244                   .getOperation();
245   }
246 
247   // Construct an AffineMap from consumer loops to producer loops.
248   // consumer loop -> tensor index
249   AffineMap consumerResultIndexMap = consumer.getInputIndexingMap(consumerIdx);
250   // producer loop -> tensor index
251   AffineMap producerResultIndexMap = producer.getOutputIndexingMap(0);
252   // tensor index -> producer loop
253   AffineMap invProducerResultIndexMap =
254       inversePermutation(producerResultIndexMap);
255   assert(invProducerResultIndexMap &&
256          "expected producer result indexig map to be invertible");
257   // consumer loop -> producer loop
258   AffineMap consumerToProducerLoopsMap =
259       invProducerResultIndexMap.compose(consumerResultIndexMap);
260 
261   generateFusedTensorOpRegion(rewriter, fusedOp.getOperation(), producer,
262                               consumer, consumerToProducerLoopsMap, consumerIdx,
263                               consumer.getNumLoops());
264   return SmallVector<Value, 1>(fusedOp->getResults());
265 }
266 
267 /// Linearize the expressions in `sourceMap` based on the `reassociationMaps`
268 /// provided, given the shape of the source tensor that corresponds to the
269 /// `sourceMap`. Note that this implicitly assumes that the tensors dimensions
270 /// are "row-major" ordered logically.
271 ///
272 /// For example:
273 ///
274 /// %0 = op ... : tensor<?x?x4x5xf32>
275 /// with output index_map `affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>`
276 ///
277 /// and reshape:
278 /// %1 = linalg.tensor_reshape %0 [affine_map<(i, j, k, l) -> (i)>,
279 ///                                affine_map<(i, j, k, l) -> (j, k, l)>] :
280 ///        tensor<?x?x4x5xf32> into tensor<?x?xf32>
281 ///
282 /// would be rewritten into:
283 /// %0 = op ... : tensor<?x?x4x5xf32>
284 /// with output index_map
285 ///   `affine_map<(d0, d1, d2, d3) -> (d0, d1 * 20 + d2 * 5 + d3)>`
linearizeCollapsedDims(AffineMap sourceMap,ArrayRef<int64_t> sourceShape,ArrayRef<AffineMap> reassociationMaps)286 static AffineMap linearizeCollapsedDims(AffineMap sourceMap,
287                                         ArrayRef<int64_t> sourceShape,
288                                         ArrayRef<AffineMap> reassociationMaps) {
289   SmallVector<AffineExpr, 4> resultExprs;
290   resultExprs.reserve(reassociationMaps.size());
291   ArrayRef<AffineExpr> sourceExprs = sourceMap.getResults();
292   MLIRContext *context = sourceMap.getContext();
293 
294   // Compute the result exprs based on the reassociation maps.
295   for (AffineMap map : reassociationMaps) {
296     ArrayRef<AffineExpr> collapsedDims = map.getResults();
297     // Assume that they are in-order and contiguous (already checked in
298     // verifier).
299     assert(!collapsedDims.empty());
300     unsigned startDim =
301         collapsedDims.front().cast<AffineDimExpr>().getPosition();
302     AffineExpr linearizedExpr = makeCanonicalStridedLayoutExpr(
303         sourceShape.slice(startDim, collapsedDims.size()),
304         sourceExprs.slice(startDim, collapsedDims.size()), context);
305     resultExprs.push_back(linearizedExpr);
306   }
307   return AffineMap::get(sourceMap.getNumDims(), sourceMap.getNumSymbols(),
308                         resultExprs, context);
309 }
310 
311 /// Checks if the `reshapeOp` can be fused with it consumer (if `asProducer` is
312 /// true) or its producer (if `asProducer` is false) given the indexing map at
313 /// its use.
isTensorReshapeOpFoldableByLinearization(TensorReshapeOp reshapeOp,AffineMap useIndexMap,bool asProducer)314 static bool isTensorReshapeOpFoldableByLinearization(TensorReshapeOp reshapeOp,
315                                                      AffineMap useIndexMap,
316                                                      bool asProducer) {
317   RankedTensorType returnType = reshapeOp.getResultType();
318   RankedTensorType operandType = reshapeOp.getSrcType();
319   // Reshape is fusable with its consumer (i.e. reshape as a producer) when its
320   // operand is of lesser rank than the result. Fusing when operand has higher
321   // rank will require use of mods and divs in the indexing maps of the fused op
322   // which would make it non-invertible. Similarly reshape is fused with its
323   // producer (i.e. reshape as consumer) only if the return type has lesser
324   // rank.
325   if ((asProducer && reshapeOp.getSrcType().hasStaticShape() &&
326        returnType.getRank() < operandType.getRank()) ||
327       (!asProducer && reshapeOp.getResultType().hasStaticShape() &&
328        operandType.getRank() < returnType.getRank()))
329     return false;
330   return useIndexMap.isPermutation();
331 }
332 
333 /// Based on the type of `op` create a linalg op of the same type, i.e. if `op`
334 /// is a linalg.generic operation, the create a `linalg.generic` operation with
335 /// the given `args`. Expects `op` to be `linalg.generic` or
336 /// `linalg.indexed_generic`.
337 template <typename... Args>
createLinalgOpOfSameType(LinalgOp op,PatternRewriter & rewriter,Args...args)338 static LinalgOp createLinalgOpOfSameType(LinalgOp op, PatternRewriter &rewriter,
339                                          Args... args) {
340   if (isa<GenericOp>(op.getOperation()))
341     return rewriter.create<GenericOp>(args...);
342   if (isa<IndexedGenericOp>(op.getOperation()))
343     return rewriter.create<IndexedGenericOp>(args...);
344   llvm_unreachable(
345       "expected only linalg.generic or linalg.indexed_generic ops");
346   return nullptr;
347 }
348 
349 /// Conditions for folding a generic/indexed-generic operation with a reshape op
350 /// by expanding the iteration space dimensionality for tensor operations. These
351 /// are preconditions assumed by `foldReshapeByDimExpansion` which implements
352 /// the following fusion pattern.
353 ///
354 ///  Consider
355 ///
356 ///  %c = linalg.generic ins(%a, %b : memref<?x?x?xf32>, memref<?x?xf32>)
357 ///         indexing_maps = [affine_map<(d0, d1, d2) -> (d1, d0, d2)>,
358 ///                          affine_map<(d0, d1, d2) -> (d1, d2)>,
359 ///                          affine_map<(d0, d1, d2) -> (d0, d2, d1)>]
360 ///  %d = linalg.tensor_reshape %c
361 ///         [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1)>,
362 ///          affine_map<(d0, d1, d2, d3, d4, d5) -> (d2)>,
363 ///          affine_map<(d0, d1, d2, d3, d4, d5) -> (d3, d4, d5)>]
364 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
365 ///
366 ///  The reshape can be folded into the `linalgOp` if the
367 ///  generic/indexed-generic op loop dimensionality is increased to match the
368 ///  result (operand) of the tensor_reshape when the reshape is expanding
369 ///  (folding). The indexing_map of the fused tensor in the `linalgOp` and the
370 ///  reassociation map helps compute the indexing maps of the modified op. For
371 ///  the above example, based on the reassociation map it can be concluded that
372 ///
373 ///  - The loop used to access the first dimension of the fused tensor is split
374 ///    into two.
375 ///  - The loop used to access the second dimension of the fused tensor is kept
376 ///    as is.
377 ///  - The loop used to access the third dimension of the fused tensor is split
378 ///    into three.
379 ///
380 ///  i.e. (e0, e1, e2, e3, e4) is the domain of the indexing map of the modified
381 ///  op, then
382 ///
383 ///   d0 -> e0, e1
384 ///   d1 -> e2, e3, e4
385 ///   d2 -> e5
386 ///
387 ///  substituting this, the generic op can be rewritten as
388 ///
389 ///  %d = linalg.generic ins(%0, %1 : )
390 ///        indexing_maps =
391 ///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e0, e1, e5)>,
392 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e2, e3, e4, e5)>,
393 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e5, e2, e3, e4)>]
394 ///
395 ///  Since operands to the linalg generic are now 5D, reshapes can be introduced
396 ///  to make it consistent
397 ///
398 ///  %0 = linalg.tensor_reshape %a
399 ///         [affine_map<(e0, e1, e2, e3, e4, e5) -> (e0, e1, e2),
400 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e3, e4),
401 ///          affine_map<(e0, e1, e2, e3, e4, e5) -> (e5)]
402 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?x?x?xf32>
403 ///  %1 = linalg.tensor_reshape %b
404 ///         [affine_map<(e0, e1, e2, e3) -> (e0, e1, e2),
405 ///          affine_map<(e0, e1, e2, e3) -> (e3)]
406 ///       : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
407 ///
408 ///  The added reshapes are again expanding patterns, so they will get fused
409 ///  with its producers if possible.
isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,unsigned fusedTensorIndex)410 static bool isFusableWithReshapeByDimExpansion(LinalgOp linalgOp,
411                                                unsigned fusedTensorIndex) {
412   // Is fusable only if:
413   // - The linalgOp is a generic op, or an indexed_generic.
414   // - All the indexing maps for operands and results in linalgOp are projected
415   //   permutations.
416   // - The fused tensor is not a scalar.
417   // - All the loops in linalgOp are parallel loops.
418   return isa<GenericOp, IndexedGenericOp>(linalgOp.getOperation()) &&
419          linalgOp.hasTensorSemantics() &&
420          llvm::all_of(linalgOp.indexing_maps().getValue(),
421                       [](Attribute attr) {
422                         return attr.cast<AffineMapAttr>()
423                             .getValue()
424                             .isProjectedPermutation();
425                       }) &&
426          linalgOp.getIndexingMap(fusedTensorIndex).getNumResults() > 0 &&
427          llvm::all_of(linalgOp.iterator_types(), [](Attribute attr) {
428            return attr.cast<StringAttr>().getValue() ==
429                   getParallelIteratorTypeName();
430          });
431 }
432 
433 /// Implements the fusion of a tensor_reshape op and a generic/indexed_generic
434 /// op as explained in `isFusableWithReshapeByDimExpansion`. Assumes that
435 /// those conditions have been satisfied.
436 static Optional<SmallVector<Value, 1>>
fuseWithReshapeByExpansion(LinalgOp linalgOp,TensorReshapeOp reshapeOp,unsigned fusedTensorIndex,PatternRewriter & rewriter)437 fuseWithReshapeByExpansion(LinalgOp linalgOp, TensorReshapeOp reshapeOp,
438                            unsigned fusedTensorIndex,
439                            PatternRewriter &rewriter) {
440   assert(isFusableWithReshapeByDimExpansion(linalgOp, fusedTensorIndex) &&
441          "preconditions for fuse operation failed");
442   // Check if reshape is expanding or collapsing.
443   bool isExpanding =
444       reshapeOp.getSrcType().getRank() < reshapeOp.getResultType().getRank();
445   RankedTensorType expandedType =
446       isExpanding ? reshapeOp.getResultType() : reshapeOp.getSrcType();
447   AffineMap fusedIndexMap = linalgOp.getIndexingMap(fusedTensorIndex);
448 
449   // The reshape is folding/expanding consecutive dimensions. Given the indexing
450   // map of the fused tensor find the number of dimensions each of the loops of
451   // the original op is expanded into. Also record the shape of the expanded
452   // dimensions.
453   ArrayRef<int64_t> expandedShape = expandedType.getShape();
454   Optional<SmallVector<int64_t, 4>> origOpLoopRange =
455       getStaticLoopRanges(linalgOp);
456   if (!origOpLoopRange) {
457     linalgOp.emitError("unable to find loop range for operation");
458     return llvm::None;
459   }
460   SmallVector<unsigned, 4> numFoldedDims(fusedIndexMap.getNumDims(), 1);
461   SmallVector<SmallVector<int64_t, 4>, 4> expandedDimsShape(
462       fusedIndexMap.getNumDims());
463   auto reassociationMaps = reshapeOp.getReassociationMaps();
464   for (auto resultExpr : llvm::enumerate(fusedIndexMap.getResults())) {
465     unsigned pos = resultExpr.value().cast<AffineDimExpr>().getPosition();
466     AffineMap foldedDims = reassociationMaps[resultExpr.index()];
467     numFoldedDims[pos] = foldedDims.getNumResults();
468     ArrayRef<int64_t> shape =
469         expandedShape.slice(foldedDims.getDimPosition(0), numFoldedDims[pos]);
470     expandedDimsShape[pos].assign(shape.begin(), shape.end());
471   }
472   // The remaining dimensions remain the same.
473   for (unsigned i : llvm::seq<unsigned>(0, fusedIndexMap.getNumDims()))
474     if (expandedDimsShape[i].empty())
475       expandedDimsShape[i] = {(*origOpLoopRange)[i]};
476 
477   if (isa<IndexedGenericOp>(linalgOp.getOperation())) {
478     // For indexed generic op, the region contains arguments that represent the
479     // induction variable value of the loops. In the fused op these values are
480     // obtained by linearizing the expanded dimensions. For now just check that
481     // the extents used in the linearization (all the expanded dims except the
482     // front) are statically know. For dynamic case, we would need shape
483     // information on these dimensions to get these.
484     for (auto &expandedShape : expandedDimsShape) {
485       if (expandedShape.size() == 1)
486         continue;
487       for (int64_t expandedDimShape : llvm::make_range(
488                std::next(expandedShape.begin()), expandedShape.end())) {
489         if (ShapedType::isDynamic(expandedDimShape)) {
490           linalgOp.emitError(
491               "unable to fuse indexed generic op where the expanded dim is "
492               "dynamic");
493           return llvm::None;
494         }
495       }
496     }
497   }
498 
499   // The remapping of the indices is then the prefix sum (inclusive) of the
500   // numFoldedDims.
501   SmallVector<unsigned, 4> remapping(numFoldedDims.size() + 1, 0);
502   unsigned sum = 0;
503   for (auto numFoldedDim : llvm::enumerate(numFoldedDims)) {
504     sum += numFoldedDim.value();
505     remapping[numFoldedDim.index() + 1] = sum;
506   }
507 
508   SmallVector<AffineMap, 4> expandedOpIndexingMaps;
509   // Compute the modified indexing maps by replacing every loop (AffineDimExpr)
510   // in the original indexing map with the sequence of loops that it is expanded
511   // to.
512   for (AffineMap indexingMap : linalgOp.getIndexingMaps()) {
513     SmallVector<AffineExpr, 4> newExprs;
514     for (AffineExpr expr : indexingMap.getResults()) {
515       unsigned pos = expr.cast<AffineDimExpr>().getPosition();
516       for (unsigned newPos :
517            llvm::seq<unsigned>(remapping[pos], remapping[pos + 1])) {
518         newExprs.push_back(rewriter.getAffineDimExpr(newPos));
519       }
520     }
521     expandedOpIndexingMaps.push_back(
522         AffineMap::get(remapping.back(), indexingMap.getNumSymbols(), newExprs,
523                        rewriter.getContext()));
524   }
525 
526   // The operands of the expanded op are computed by reshaping the original
527   // operands. The reshape depends on the ordering of the loop used to access
528   // the tensor in the original operation, and are expanded into as many
529   // dimensions as the loop is expanded into (as computed by `remapping`).
530   auto getReshapeInfo =
531       [&](AffineMap operandIndexingMap,
532           SmallVectorImpl<ReassociationIndices> &reassociation,
533           SmallVectorImpl<int64_t> &expandedOpOperandShape) {
534         unsigned reshapeDims = 0;
535         for (AffineExpr expr : operandIndexingMap.getResults()) {
536           unsigned origDim = expr.cast<AffineDimExpr>().getPosition();
537           auto foldedDims = llvm::seq<int64_t>(
538               reshapeDims, reshapeDims + numFoldedDims[origDim]);
539           reassociation.emplace_back(foldedDims.begin(), foldedDims.end());
540           expandedOpOperandShape.append(expandedDimsShape[origDim].begin(),
541                                         expandedDimsShape[origDim].end());
542           reshapeDims += numFoldedDims[origDim];
543         }
544       };
545   SmallVector<Value, 4> expandedOpOperands;
546   for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
547     if (operand.index() == fusedTensorIndex) {
548       expandedOpOperands.push_back(reshapeOp.src());
549       continue;
550     }
551     AffineMap indexingMap = linalgOp.getIndexingMap(operand.index());
552     SmallVector<ReassociationIndices, 4> reassociation;
553     SmallVector<int64_t, 4> expandedOperandShape;
554     getReshapeInfo(indexingMap, reassociation, expandedOperandShape);
555     Type expandedOperandType = RankedTensorType::get(
556         expandedOperandShape,
557         operand.value().getType().cast<ShapedType>().getElementType());
558     if (expandedOperandType != operand.value().getType()) {
559       expandedOpOperands.push_back(rewriter.create<TensorReshapeOp>(
560           linalgOp.getLoc(), expandedOperandType, operand.value(),
561           reassociation));
562     } else {
563       expandedOpOperands.push_back(operand.value());
564     }
565   }
566   SmallVector<Type, 1> resultTypes;
567   SmallVector<SmallVector<ReassociationIndices, 4>, 1> resultReassociation;
568   for (auto result : llvm::enumerate(linalgOp->getResults())) {
569     AffineMap indexingMap =
570         linalgOp.getIndexingMap(linalgOp.getNumInputs() + result.index());
571     SmallVector<ReassociationIndices, 4> reassociation;
572     SmallVector<int64_t, 4> expandedResultShape;
573     getReshapeInfo(indexingMap, reassociation, expandedResultShape);
574     resultTypes.push_back(RankedTensorType::get(
575         expandedResultShape,
576         result.value().getType().cast<ShapedType>().getElementType()));
577     resultReassociation.emplace_back(std::move(reassociation));
578   }
579 
580   // The iterator types of the expanded op are all parallel.
581   SmallVector<StringRef, 4> iteratorTypes(remapping.back(),
582                                           getParallelIteratorTypeName());
583 
584   LinalgOp fusedOp = createLinalgOpOfSameType(
585       linalgOp, rewriter, linalgOp.getLoc(), resultTypes,
586       /*inputs=*/expandedOpOperands,
587       /*outputBuffers=*/ValueRange{},
588       /*initTensors=*/ValueRange{}, expandedOpIndexingMaps, iteratorTypes);
589   Region &fusedRegion = fusedOp->getRegion(0);
590   Region &originalRegion = linalgOp->getRegion(0);
591 
592   if (isa<GenericOp>(linalgOp.getOperation())) {
593     rewriter.cloneRegionBefore(originalRegion, fusedRegion,
594                                fusedRegion.begin());
595   } else {
596     assert(isa<IndexedGenericOp>(linalgOp.getOperation()));
597     // Create an entry block in the fused Region with same number of arguments
598     // as the fused op
599     Block *fusedEntryBlock = new Block;
600     fusedRegion.push_back(fusedEntryBlock);
601     rewriter.cloneRegionBefore(originalRegion, fusedRegion, fusedRegion.end());
602 
603     // Merge the entry block of the fused op with the cloned blocks. For this
604     // compute the value for arguments of the region in the original operation
605     // in terms of the arguments of the fused op. Since the original operation
606     // is expanded, the expanded dimensions need to be folded back to get the
607     // replacement value for the arguments corresponding to interation index.
608     // For now this expects that all the loop ranges are constants, which is
609     // true if the shapes are all static. This has already been checked in the
610     // precondition.
611     using namespace edsc::op;
612     using namespace edsc::intrinsics;
613     OpBuilder::InsertionGuard guard(rewriter);
614     SmallVector<Value, 4> argReplacements(originalRegion.getNumArguments());
615     rewriter.setInsertionPointToStart(fusedEntryBlock);
616     edsc::ScopedContext scopedContext(rewriter, fusedOp.getLoc());
617     IndexType indexType = rewriter.getIndexType();
618     for (unsigned i : llvm::seq<unsigned>(0, numFoldedDims.size())) {
619       Value linearizedIndex = fusedEntryBlock->addArgument(indexType);
620       for (unsigned foldedDim = remapping[i] + 1; foldedDim != remapping[i + 1];
621            foldedDim++) {
622         int64_t expandedDimExtent =
623             expandedDimsShape[i][foldedDim - remapping[i]];
624         assert(!ShapedType::isDynamic(expandedDimExtent));
625         linearizedIndex =
626             linearizedIndex * std_constant_index(expandedDimExtent);
627         linearizedIndex =
628             linearizedIndex + fusedEntryBlock->addArgument(indexType);
629       }
630       argReplacements[i] = linearizedIndex;
631     }
632     for (unsigned i :
633          llvm::seq<unsigned>(numFoldedDims.size(), argReplacements.size())) {
634       argReplacements[i] =
635           fusedEntryBlock->addArgument(originalRegion.getArgument(i).getType());
636     }
637     rewriter.mergeBlocks(fusedEntryBlock->getNextNode(), fusedEntryBlock,
638                          argReplacements);
639   }
640 
641   // Reshape the result values to their original shape if this is a collapsing
642   // reshape folded into its consumer.
643   SmallVector<Value, 1> resultVals;
644   for (auto result : llvm::enumerate(linalgOp->getResults())) {
645     if (!isExpanding &&
646         resultTypes[result.index()] != result.value().getType()) {
647       resultVals.push_back(rewriter.create<TensorReshapeOp>(
648           linalgOp.getLoc(), result.value().getType(),
649           fusedOp->getResult(result.index()),
650           resultReassociation[result.index()]));
651     } else {
652       resultVals.push_back(fusedOp->getResult(result.index()));
653     }
654   }
655   // Assuming a single result.
656   return resultVals;
657 }
658 
659 namespace {
660 
661 /// Pattern to fold tensor_reshape op with its consumer by using the source of
662 /// the reshape op as the operand in the consumer (instead of the result of the
663 /// tensor_reshapeop) when the tensor_reshape op is collapsing. The
664 /// corresponding index map in the consumer needs to be modified to linearize
665 /// the folded dimension.
666 ///
667 /// For example,
668 ///
669 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
670 /// %0 = linalg.tensor_reshape %arg0
671 ///        [affine_map<(i, j, k, l) -> (i)>, affine_map<(i, j, k, l) -> (j, k)>,
672 ///         affine_map<(i, j, k, l) -> (l)>]
673 ///      tensor<?x?x?xf32> into tensor<?x?x4x?xf32>
674 /// %1 = linalg.generic { indexing_maps = [#map0, #map0, #map0], ... }
675 ///        ins(%0, %arg1 : tensor<?x?x4x?xf32>, tensor<?x?x4x?xf32>) ...
676 ///        -> tensor<?x?x4x?xf32>
677 ///
678 /// can be folded into
679 ///
680 /// #map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1 * 4 + d2, d3)>
681 /// #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
682 /// %0 = linalg.generic { indexing_maps = [#map0, #map1, #map1] ... }
683 ///        ins(%arg0, %arg1 : tensor<?x?x?xf32>, tensor<?x?x4x?xf32>) ...
684 ///        -> tensor<?x?x4x?xf32>
685 template <typename LinalgOpTy>
686 struct FoldProducerReshapeOpByLinearization
687     : public OpRewritePattern<LinalgOpTy> {
688   using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
689 
matchAndRewrite__anon96e7b0280411::FoldProducerReshapeOpByLinearization690   LogicalResult matchAndRewrite(LinalgOpTy op,
691                                 PatternRewriter &rewriter) const override {
692     if (!op.hasTensorSemantics())
693       return failure();
694     LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
695     for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
696       TensorReshapeOp reshapeOp =
697           operand.value().getDefiningOp<TensorReshapeOp>();
698       if (!reshapeOp ||
699           !isTensorReshapeOpFoldableByLinearization(
700               reshapeOp, linalgOp.getInputIndexingMap(operand.index()),
701               /*asProducer =*/true))
702         continue;
703 
704       // Compute the fused operands list,
705       SmallVector<Value, 2> fusedOperands(linalgOp.getInputs());
706       fusedOperands[operand.index()] = reshapeOp.src();
707 
708       // Compute indexing_maps for the fused operation. The indexing_maps for
709       // the operands of the consumers that arent fused are the same.
710       SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
711           op.indexing_maps().template getAsValueRange<AffineMapAttr>());
712 
713       // Accepted consumer maps are either identity or permutation.
714       auto invMap = inversePermutation(fusedIndexMaps[operand.index()]);
715 
716       // Compute the indexing map to use for the result of the producer.
717       AffineMap modifiedMap =
718           linearizeCollapsedDims(invMap, reshapeOp.getResultType().getShape(),
719                                  reshapeOp.getReassociationMaps());
720       for (AffineExpr expr : modifiedMap.getResults()) {
721         if (!expr.isPureAffine())
722           return failure();
723       }
724       fusedIndexMaps[operand.index()] = modifiedMap;
725 
726       // Further check that the resulting index maps can be fused and
727       // inverted. Without this the resultant op is not legal.
728       if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
729         return op.emitRemark("fused op loop bound computation failed");
730 
731       rewriter.startRootUpdate(op);
732       op->setOperands(fusedOperands);
733       op.indexing_mapsAttr(rewriter.getAffineMapArrayAttr(fusedIndexMaps));
734       rewriter.finalizeRootUpdate(op);
735       if (reshapeOp.use_empty())
736         rewriter.eraseOp(reshapeOp);
737       return success();
738     }
739     return op.emitRemark("no fusion candidates found");
740   }
741 };
742 
743 /// Pattern to fuse a tensor_reshape op with its consumer
744 /// generic/indexed_generic op, when the reshape op is collapsing
745 /// dimensions. The dimensionality of the loop in the consumer is expanded.
746 template <typename GenericOpTy>
747 struct FoldWithProducerReshapeOpByExpansion
748     : public OpRewritePattern<GenericOpTy> {
749   using OpRewritePattern<GenericOpTy>::OpRewritePattern;
750 
matchAndRewrite__anon96e7b0280411::FoldWithProducerReshapeOpByExpansion751   LogicalResult matchAndRewrite(GenericOpTy genericOp,
752                                 PatternRewriter &rewriter) const override {
753     LinalgOp linalgOp = cast<LinalgOp>(genericOp.getOperation());
754     for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
755       TensorReshapeOp reshapeOp =
756           operand.value().getDefiningOp<TensorReshapeOp>();
757       if (!reshapeOp)
758         continue;
759 
760       // Fold only if
761       // - The tensor reshape op is folding.
762       // - All constraints of fusing with reshape by expansion are met.
763       if (reshapeOp.getSrcType().getRank() <
764               reshapeOp.getResultType().getRank() ||
765           !isFusableWithReshapeByDimExpansion(linalgOp, operand.index()))
766         continue;
767 
768       Optional<SmallVector<Value, 1>> replacementValues =
769           fuseWithReshapeByExpansion(linalgOp, reshapeOp, operand.index(),
770                                      rewriter);
771       if (!replacementValues)
772         return failure();
773       rewriter.replaceOp(genericOp, replacementValues.getValue());
774       if (reshapeOp.use_empty())
775         rewriter.eraseOp(reshapeOp);
776       return success();
777     }
778     return failure();
779   }
780 };
781 
782 /// Pattern to fold tensor_reshape op with its producer. The corresponding index
783 /// map in the consumer needs to be modified to linearize the folded dimension.
784 struct FoldConsumerReshapeOpByLinearization
785     : public OpRewritePattern<TensorReshapeOp> {
786   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
787 
matchAndRewrite__anon96e7b0280411::FoldConsumerReshapeOpByLinearization788   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
789                                 PatternRewriter &rewriter) const override {
790     LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>();
791     if (!producer ||
792         !isa<GenericOp, IndexedGenericOp>(producer.getOperation()) ||
793         !producer.hasTensorSemantics() || producer.getNumOutputs() != 1 ||
794         !isTensorReshapeOpFoldableByLinearization(
795             reshapeOp, producer.getOutputIndexingMap(0), /*asProducer =*/false))
796       return failure();
797     // The indexing_maps for the operands of the fused operation are same as
798     // those for the operands of the producer.
799     SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
800         producer.indexing_maps().getAsValueRange<AffineMapAttr>());
801 
802     auto invMap = inversePermutation(producer.getOutputIndexingMap(0));
803 
804     // Compute the indexing map to use for the operand of the producer.
805     AffineMap modifiedMap =
806         linearizeCollapsedDims(invMap, reshapeOp.getSrcType().getShape(),
807                                reshapeOp.getReassociationMaps());
808     for (AffineExpr expr : modifiedMap.getResults()) {
809       if (!expr.isPureAffine())
810         return reshapeOp.emitRemark("fused op indexing map is not affine");
811     }
812     fusedIndexMaps.back() = modifiedMap;
813 
814     // Further check that the resulting index maps can be fused and
815     // inverted. Without this the resultant op is not legal.
816     if (!inversePermutation(concatAffineMaps(fusedIndexMaps)))
817       return reshapeOp.emitRemark("fused op loop bound computation failed");
818 
819     LinalgOp fusedOp = createLinalgOpOfSameType(
820         producer, rewriter, rewriter.getUnknownLoc(), reshapeOp.getResultType(),
821         /*inputs=*/producer.getInputs(),
822         /*outputBuffers=*/ValueRange{},
823         /*initTensors=*/ValueRange{}, // no init tensors for now.
824         rewriter.getAffineMapArrayAttr(fusedIndexMaps),
825         producer.iterator_types(),
826         /*doc=*/nullptr,
827         /*library_call=*/nullptr,
828         /*sparse=*/nullptr);
829     auto &fusedRegion = fusedOp->getRegion(0);
830     rewriter.cloneRegionBefore(producer->getRegion(0), fusedRegion,
831                                fusedRegion.begin());
832     rewriter.replaceOp(reshapeOp, fusedOp->getResults());
833     if (producer.use_empty())
834       rewriter.eraseOp(producer);
835     return success();
836   }
837 };
838 
839 /// Pattern to fold a tensor_reshape op with its producer generic op if the
840 /// tensor_reshape op is expanding, by expanding the dimensionality of the loop
841 /// in the producer op.
842 struct FoldReshapeWithGenericOpByExpansion
843     : public OpRewritePattern<TensorReshapeOp> {
844   using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
matchAndRewrite__anon96e7b0280411::FoldReshapeWithGenericOpByExpansion845   LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
846                                 PatternRewriter &rewriter) const override {
847     // Fold only if
848     // - The tensor reshape op is a expanding case.
849     // - All constraints of fusing with reshape by expansion are met.
850     if (reshapeOp.getSrcType().getRank() > reshapeOp.getResultType().getRank())
851       return failure();
852     LinalgOp producer = reshapeOp.src().getDefiningOp<LinalgOp>();
853     if (!producer || producer.getNumOutputs() != 1 ||
854         !isFusableWithReshapeByDimExpansion(producer, producer.getNumInputs()))
855       return failure();
856     Optional<SmallVector<Value, 1>> replacementValues =
857         fuseWithReshapeByExpansion(producer, reshapeOp, producer.getNumInputs(),
858                                    rewriter);
859     if (!replacementValues)
860       return failure();
861     rewriter.replaceOp(reshapeOp, replacementValues.getValue());
862     if (producer.use_empty())
863       rewriter.eraseOp(producer);
864     return success();
865   }
866 };
867 
868 /// Pattern to fold a GenericOp/IndexedGenericOp with a splat constant.
869 template <typename LinalgOpTy>
870 struct FoldSplatConstants : public OpRewritePattern<LinalgOpTy> {
871   using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
872 
matchAndRewrite__anon96e7b0280411::FoldSplatConstants873   LogicalResult matchAndRewrite(LinalgOpTy op,
874                                 PatternRewriter &rewriter) const override {
875     if (!op.hasTensorSemantics())
876       return failure();
877     LinalgOp linalgOp = cast<LinalgOp>(op.getOperation());
878     for (auto operand : llvm::enumerate(linalgOp.getInputs())) {
879       ConstantOp constantOp = operand.value().getDefiningOp<ConstantOp>();
880       if (!constantOp ||
881           !constantOp.value().cast<DenseElementsAttr>().isSplat())
882         continue;
883 
884       // The indexing_maps for the operands of the fused operation are same as
885       // those for the operands of the linalgOp without the indexing map at
886       // operand.index()
887       SmallVector<AffineMap, 4> fusedIndexMaps = llvm::to_vector<4>(
888           linalgOp.indexing_maps().getAsValueRange<AffineMapAttr>());
889       fusedIndexMaps.erase(std::next(fusedIndexMaps.begin(), operand.index()));
890 
891       // The operands list is same as the linalgOp with the argument for
892       // constant index dropped.
893       SmallVector<Value, 4> fusedOperands(linalgOp.getInputs());
894       fusedOperands.erase(std::next(fusedOperands.begin(), operand.index()));
895 
896       // Create a constant scalar value from the splat constant.
897       Value scalarConstant = rewriter.create<ConstantOp>(
898           constantOp.getLoc(),
899           constantOp.value().cast<DenseElementsAttr>().getSplatValue());
900 
901       LinalgOp fusedOp = createLinalgOpOfSameType(
902           linalgOp, rewriter, rewriter.getUnknownLoc(),
903           linalgOp->getResultTypes(),
904           /*inputs=*/fusedOperands,
905           /*outputBuffers=*/ValueRange{},
906           /*initTensors=*/ValueRange{}, // no init tensors for now.
907           rewriter.getAffineMapArrayAttr(fusedIndexMaps),
908           linalgOp.iterator_types(),
909           /*doc=*/nullptr,
910           /*library_call=*/nullptr,
911           /*sparse=*/nullptr);
912 
913       // Map the block argument corresponding to the replaced argument with the
914       // scalar constant.
915       Region &linalgOpRegion = linalgOp->getRegion(0);
916       Block &entryBlock = *linalgOpRegion.begin();
917       unsigned argIndex = entryBlock.getNumArguments() -
918                           linalgOp.getNumInputs() + operand.index();
919       BlockAndValueMapping mapping;
920       mapping.map(entryBlock.getArgument(argIndex), scalarConstant);
921       Region &fusedRegion = fusedOp->getRegion(0);
922       rewriter.cloneRegionBefore(linalgOpRegion, fusedRegion,
923                                  fusedRegion.begin(), mapping);
924       rewriter.replaceOp(linalgOp, fusedOp->getResults());
925       if (constantOp.use_empty())
926         rewriter.eraseOp(constantOp);
927       return success();
928     }
929     return failure();
930   }
931 };
932 } // namespace
933 
934 Optional<SmallVector<Value, 1>>
fuseTensorOps(PatternRewriter & rewriter,Operation * consumer,unsigned consumerIdx)935 mlir::linalg::fuseTensorOps(PatternRewriter &rewriter, Operation *consumer,
936                             unsigned consumerIdx) {
937   if (consumerIdx >= consumer->getNumOperands())
938     return llvm::None;
939   Operation *producer = consumer->getOperand(consumerIdx).getDefiningOp();
940   if (!producer || producer->getNumResults() != 1)
941     return llvm::None;
942 
943   // Fuse when consumer is GenericOp or IndexedGenericOp.
944   if (!isa<GenericOp, IndexedGenericOp>(consumer) ||
945       !isa<GenericOp, IndexedGenericOp>(producer))
946     return llvm::None;
947 
948   return fuseTensorOpsImpl(cast<LinalgOp>(producer), cast<LinalgOp>(consumer),
949                            consumerIdx, rewriter);
950 }
951 
952 namespace {
953 /// Patterns to fuse a generic op, with the producer of its operands.
954 template <typename LinalgOpTy>
955 struct FuseTensorOps : public OpRewritePattern<LinalgOpTy> {
956   using OpRewritePattern<LinalgOpTy>::OpRewritePattern;
957 
matchAndRewrite__anon96e7b0280511::FuseTensorOps958   LogicalResult matchAndRewrite(LinalgOpTy op,
959                                 PatternRewriter &rewriter) const override {
960     // Find the first operand that is defined by another generic op on tensors.
961     for (auto operandNum : llvm::seq<unsigned>(0, op->getNumOperands())) {
962       Operation *producer = op->getOperand(operandNum).getDefiningOp();
963       if (!producer)
964         continue;
965       Optional<SmallVector<Value, 1>> fusedOpResults =
966           fuseTensorOps(rewriter, op, operandNum);
967       if (fusedOpResults) {
968         rewriter.replaceOp(op, *fusedOpResults);
969         if (producer->use_empty())
970           rewriter.eraseOp(producer);
971         return success();
972       }
973     }
974     return failure();
975   }
976 };
977 
978 /// Pass that fuses generic ops on tensors. Used only for testing.
979 struct FusionOfTensorOpsPass
980     : public LinalgFusionOfTensorOpsBase<FusionOfTensorOpsPass> {
runOnOperation__anon96e7b0280511::FusionOfTensorOpsPass981   void runOnOperation() override {
982     OwningRewritePatternList patterns;
983     Operation *op = getOperation();
984     populateLinalgTensorOpsFusionPatterns(op->getContext(), patterns);
985     applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
986   }
987 };
988 
989 /// Pass to test folding of reshape op with generic/indexed_generic ops by
990 /// linearization.
991 struct FoldReshapeOpsByLinearizationPass
992     : public LinalgFoldReshapeOpsByLinearizationBase<
993           FoldReshapeOpsByLinearizationPass> {
runOnOperation__anon96e7b0280511::FoldReshapeOpsByLinearizationPass994   void runOnOperation() override {
995     OwningRewritePatternList patterns;
996     Operation *op = getOperation();
997     populateFoldReshapeOpsByLinearizationPatterns(op->getContext(), patterns);
998     applyPatternsAndFoldGreedily(op->getRegions(), std::move(patterns));
999   }
1000 };
1001 
1002 } // namespace
1003 
populateFoldReshapeOpsByLinearizationPatterns(MLIRContext * context,OwningRewritePatternList & patterns)1004 void mlir::populateFoldReshapeOpsByLinearizationPatterns(
1005     MLIRContext *context, OwningRewritePatternList &patterns) {
1006   patterns.insert<FoldProducerReshapeOpByLinearization<GenericOp>,
1007                   FoldProducerReshapeOpByLinearization<IndexedGenericOp>,
1008                   FoldConsumerReshapeOpByLinearization>(context);
1009 }
1010 
populateFoldReshapeOpsByExpansionPatterns(MLIRContext * context,OwningRewritePatternList & patterns)1011 void mlir::populateFoldReshapeOpsByExpansionPatterns(
1012     MLIRContext *context, OwningRewritePatternList &patterns) {
1013   patterns.insert<FoldReshapeWithGenericOpByExpansion,
1014                   FoldWithProducerReshapeOpByExpansion<GenericOp>,
1015                   FoldWithProducerReshapeOpByExpansion<IndexedGenericOp>>(
1016       context);
1017 }
1018 
populateLinalgTensorOpsFusionPatterns(MLIRContext * context,OwningRewritePatternList & patterns)1019 void mlir::populateLinalgTensorOpsFusionPatterns(
1020     MLIRContext *context, OwningRewritePatternList &patterns) {
1021   patterns.insert<FuseTensorOps<GenericOp>, FuseTensorOps<IndexedGenericOp>,
1022                   FoldSplatConstants<GenericOp>,
1023                   FoldSplatConstants<IndexedGenericOp>>(context);
1024   populateFoldReshapeOpsByExpansionPatterns(context, patterns);
1025   GenericOp::getCanonicalizationPatterns(patterns, context);
1026   IndexedGenericOp::getCanonicalizationPatterns(patterns, context);
1027   TensorReshapeOp::getCanonicalizationPatterns(patterns, context);
1028 }
1029 
createLinalgFusionOfTensorOpsPass()1030 std::unique_ptr<Pass> mlir::createLinalgFusionOfTensorOpsPass() {
1031   return std::make_unique<FusionOfTensorOpsPass>();
1032 }
1033 
createFoldReshapeOpsByLinearizationPass()1034 std::unique_ptr<Pass> mlir::createFoldReshapeOpsByLinearizationPass() {
1035   return std::make_unique<FoldReshapeOpsByLinearizationPass>();
1036 }
1037